Full Code of shell-nlp/gpt_server for AI
Repository: shell-nlp/gpt_server
Branch: main
Commit: 1d266b0b2a50
Files: 93
Total size: 347.1 KB

Directory structure:
gitextract_z12pzj5x/

├── .dockerignore
├── .github/
│   └── workflows/
│       └── docker-image.yml
├── .gitignore
├── .python-version
├── Dockerfile
├── Dockerfile.copy
├── LICENSE
├── MANIFEST.in
├── README.md
├── docker-compose-bash.yaml
├── docker-compose.yml
├── gpt_server/
│   ├── __init__.py
│   ├── cli.py
│   ├── database/
│   │   └── models/
│   │       └── process_manager.py
│   ├── model_backend/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── hf_backend.py
│   │   ├── lmdeploy_backend.py
│   │   ├── sglang_backend.py
│   │   ├── utils.py
│   │   └── vllm_backend.py
│   ├── model_handler/
│   │   ├── __init__.py
│   │   ├── chat_template/
│   │   │   ├── get_chat_template.py
│   │   │   ├── qwen3.jinja
│   │   │   ├── qwen3_zh.jinja
│   │   │   └── qwen3vl.jinja
│   │   ├── pitch.py
│   │   ├── reasoning_parser.py
│   │   ├── tool_parser.py
│   │   └── utils.py
│   ├── model_worker/
│   │   ├── __init__.py
│   │   ├── auto.py
│   │   ├── base/
│   │   │   ├── __init__.py
│   │   │   ├── base_model_worker.py
│   │   │   └── model_worker_base.py
│   │   ├── embedding_infinity.py
│   │   ├── embedding_sentence_transformers.py
│   │   ├── embedding_v2.py
│   │   ├── embedding_vllm.py
│   │   ├── flux.py
│   │   ├── funasr.py
│   │   ├── qwen_image.py
│   │   ├── qwen_image_edit.py
│   │   ├── spark_tts.py
│   │   ├── utils.py
│   │   ├── voxcpm_tts.py
│   │   ├── wan.py
│   │   └── z_image.py
│   ├── openai_api_protocol/
│   │   ├── __init__.py
│   │   └── custom_api_protocol.py
│   ├── script/
│   │   ├── __init__.py
│   │   ├── config_example.yaml
│   │   ├── start.sh
│   │   └── stop.sh
│   ├── serving/
│   │   ├── __init__.py
│   │   ├── chat_ui.py
│   │   ├── controller.py
│   │   ├── controller_v2.py
│   │   ├── main.py
│   │   ├── openai_api_server.py
│   │   └── server_ui.py
│   ├── settings.py
│   ├── utils.py
│   └── version.py
├── pyproject.toml
├── setup.py
└── tests/
    ├── download_model.py
    ├── responses_api/
    │   ├── test_openai_responses.py
    │   ├── test_openai_responses_response_format.py
    │   ├── test_openai_responses_tool_calling.py
    │   └── test_response_vl_chat.py
    ├── sglang/
    │   └── models.py
    ├── test_chat_template.py
    ├── test_embedding_dynamic_batch.py
    ├── test_image_edit.py
    ├── test_image_gen.py
    ├── test_mteb.py
    ├── test_needle_haystack.py
    ├── test_openai_chat.py
    ├── test_openai_completion.py
    ├── test_openai_completion_response_format.py
    ├── test_openai_completion_tool_calling.py
    ├── test_openai_embedding.py
    ├── test_openai_embedding_vl.py
    ├── test_openai_moderation.py
    ├── test_openai_rerank.py
    ├── test_openai_transcriptions.py
    ├── test_openai_tts_stream.py
    ├── test_openai_vl_chat.py
    ├── test_perf.py
    ├── test_rerank.py
    └── vllm/
        ├── embedding.py
        └── models.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .dockerignore
================================================
test/
.vscode/
.venv/
__pycache__/
*.log*
*.egg-info
logs/
outputs/
data/
.env

================================================
FILE: .github/workflows/docker-image.yml
================================================
name: Docker Image CI

on:
  # release:
  #   types: 
  #     - published  # triggered when a new release is published
  push:
    branches:
    - build_image # trigger a build when pushing to the build_image branch
    - set_latest

jobs:

  build_version:
    if: github.ref  == 'refs/heads/build_image'
    runs-on: ubuntu-latest

    steps:
      # Free up disk space
      - name: Clean up Docker build cache 
        run: docker system prune -af --volumes 

      - name: Free up disk space 
        run: |
          sudo rm -rf /usr/share/dotnet 
          sudo rm -rf /opt/ghc 
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"

      # Check out the code
      - name: Checkout code
        uses: actions/checkout@v3
      # Log in to Docker Hub
      - name: Log in to Docker Hub
        run: echo "${{ secrets.DOCKER_PASSWORD }}" | docker login -u "${{ secrets.DOCKER_USERNAME }}" --password-stdin
      # Extract the version from pyproject.toml
      - name: Extract version
        id: get_version
        run: |
          # Use grep with a PCRE lookbehind to extract the version from pyproject.toml
          version=$(grep -Po '(?<=^version = ")[^"]*' pyproject.toml)
          echo "VERSION=$version" >> $GITHUB_ENV

      # Build the Docker image
      - name: Build Docker image
        run: |
          docker build -f Dockerfile -t ${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }} .
          # docker tag ${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }} ${{ secrets.DOCKER_USERNAME }}/gpt_server:latest
      # Push the image to Docker Hub
      - name: Push Docker image
        run: |
          docker push ${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }}
          # docker push ${{ secrets.DOCKER_USERNAME }}/gpt_server:latest
  tag_latest:
    if: github.ref  == 'refs/heads/set_latest'
    runs-on: ubuntu-latest 
    steps:
      # Free up disk space (step 1)
      - name: Maximize build space
        uses: easimon/maximize-build-space@master
        with:
          root-reserve-mb: 5120
          swap-size-mb: 1024
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
      # Free up disk space (step 2)
      - name: Clean up Docker build cache 
        run: docker system prune -af --volumes 

      - name: Free up disk space 
        run: |
          sudo rm -rf /usr/share/dotnet 
          sudo rm -rf /opt/ghc 
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
          
      - name: Checkout code 
        uses: actions/checkout@v3
 
      - name: Log in to Docker Hub
        run: echo "${{ secrets.DOCKER_PASSWORD }}" | docker login -u "${{ secrets.DOCKER_USERNAME }}" --password-stdin
 
      - name: Extract version 
        id: get_version 
        run: |
          version=$(grep -Po '(?<=^version = ")[^"]*' pyproject.toml) 
          echo "VERSION=$version" >> $GITHUB_ENV 
      # Install skopeo
      - name: Install skopeo
        run: |
          sudo apt-get update
          sudo apt-get install -y skopeo
      # Use skopeo to retag the remote image efficiently
      # This runs directly against Docker Hub and does not download any image layers
      - name: Retag remote image without pulling
        run: |
          skopeo copy \
            docker://${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }} \
            docker://${{ secrets.DOCKER_USERNAME }}/gpt_server:latest

      # - name: Pull and tag latest 
      #   run: |
      #     # Pull the existing versioned image
      #     docker pull ${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }}
      #     # Add only the latest tag and push it
      #     docker tag ${{ secrets.DOCKER_USERNAME }}/gpt_server:${{ env.VERSION }} ${{ secrets.DOCKER_USERNAME }}/gpt_server:latest
      #     docker push ${{ secrets.DOCKER_USERNAME }}/gpt_server:latest

================================================
FILE: .gitignore
================================================
.vscode/
__pycache__/
*.log*
*.egg-info
test/
logs/
outputs/
data/
.venv/
config.yaml
.env
*_test.yaml

================================================
FILE: .python-version
================================================
3.11


================================================
FILE: Dockerfile
================================================
# FROM docker.1ms.run/506610466/cuda:12.2.2-runtime-ubuntu20.04-uv
FROM 506610466/cuda:12.2.2-devel-ubuntu20.04-uv
# Build from the base image to speed up the build
# FROM 506610466/gpt_server:base
RUN apt-get update -y && apt-get install -y git numactl build-essential && rm -rf /var/lib/apt/lists/*
COPY ./ /gpt_server
WORKDIR /gpt_server
# RUN uv sync && uv cache clean
ENV UV_HTTP_TIMEOUT=120 CUDA_HOME=/usr/local/cuda-12.2
ENV PATH=$CUDA_HOME/bin:$PATH 
ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
RUN uv venv --seed && uv sync -v && uv cache clean && \
    echo '[[ -f .venv/bin/activate ]] && source .venv/bin/activate' >> ~/.bashrc
ENV PATH=/gpt_server/.venv/bin:$PATH

CMD ["/bin/bash"]

================================================
FILE: Dockerfile.copy
================================================
FROM docker.1ms.run/506610466/gpt_server:latest 

COPY ./ /gpt_server

WORKDIR /gpt_server

CMD ["/bin/bash"]

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: MANIFEST.in
================================================
include gpt_server/script/*.yaml

================================================
FILE: README.md
================================================
<div align="center">

<a href="https://github.com/shell-nlp/gpt_server"><img src="assets/logo.png" width="252" height="116" alt="gpt_server logo"></a>

# GPT Server
[![License][license-shield]][license-url]
[![Stars][stars-shield]][stars-url]
[![Forks][forks-shield]][forks-url]
[![Docker pulls][docker-pulls]][docker-pulls]
[![CI Status][ci-shield]][ci-url]
[![issue resolution][closed-issues-shield]][closed-issues-url]

</div>

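This project builds on the core capabilities of fastchat to provide an **OpenAI server**.

1. Serves **Chat**, **Embedding**, **ReRanker**, **text-moderation (text review/classification)**, **ASR**, **TTS (with voice cloning)**, and **SD (Stable Diffusion: text-to-image, text-to-video, image editing)** models through **OpenAI**-compatible APIs.
2. Supports multiple accelerated inference backends: **HF**, **vLLM**, **LMDeploy**, and **SGLang**.
3. Multiple models share the same **openai server** port, and requests are routed to the right model automatically.

If GPT Server is helpful to you, please leave a ⭐ Star!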
<br>

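## ✨ Features
|     | Feature          | Description                                                                |
|-----|-------------|-------------------------------------------------------------------|
| 🎨  | **OpenAI API**     | Follows the `OpenAI` API specification and works with any project that supports OpenAI |
| 💎  | **`Responses API` support**     | The world's first server compatible with the `OpenAI` `Responses API` |
| 🚀  | **Multiple inference backends** | Supports the high-performance engines `vLLM`, `SGLang`, `LMDeploy`, and `HF` |
| 🎯  | **Embedding/Reranker** | Supports all embedding and reranking models compatible with `sentence_transformers`; the Infinity backend adds dynamic batching, and its **Embedding** inference is faster than onnx/tensorrt |
| 🎛️ | **Text-moderation (text review/classification)**   | Text moderation and classification following the `OpenAI` API specification |
| 📱  | **ASR (speech-to-text)**    | Supports ASR models based on `FunASR` |
| 🔊  | **TTS (text-to-speech)**   | Supports TTS models based on `SparkTTS`, accelerated with the `vLLM` and `SGLang` backends (`RTF<<1`), with streaming audio output |
| 🖌️  | **SD (Stable Diffusion, text-to-image)**    | Supports `diffusers`-based text-to-image models |
| 🏔️  | **SD (Stable Diffusion, image editing)**    | Supports `diffusers`-based image editing models |
| 🔄  | **LM/VL models**  | Supports a wide range of large language models and multimodal (vision-language) models |
| 🎭  | **Serving performance benchmarks**   | Measures `Throughput`, `TTFT`, `TPOT`, and other serving metrics with `Evalscope` |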

<br>

### Other features
- Supports a /v1/rerank endpoint that follows the `cohere` API specification and can be used in Dify.
- Extends the `OpenAI` library with Reranker models (rerank, /v1/rerank). (Example: gpt_server/tests/test_openai_rerank.py)
- Supports the `OpenAI` text moderation API (text-moderation, /v1/moderations). (Example: gpt_server/tests/test_openai_moderation.py)
- Supports the `OpenAI` TTS API (tts, /v1/audio/speech). (Example: gpt_server/tests/test_openai_tts_stream.py)
- Supports the `OpenAI` ASR API (asr, /v1/audio/transcriptions) with a FunASR backend. (Example: gpt_server/tests/test_openai_transcriptions.py)
- Supports the `OpenAI` SD text-to-image API (sd, /v1/images/generations) with a diffusers backend. (Example: gpt_server/tests/test_image_gen.py)
- Supports the `OpenAI` SD image editing API (sd, /v1/images/edits) with a diffusers backend. (Example: gpt_server/tests/test_image_edit.py)
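
A minimal sketch of a Cohere-style call to `/v1/rerank` using only the Python standard library. It assumes the request body follows the Cohere rerank convention (`model`, `query`, `documents`, optional `top_n`); the helper names `build_rerank_request` and `rerank`, the base URL, and the model name are hypothetical, so check gpt_server/tests/test_openai_rerank.py for the exact fields the server accepts.

```python
from __future__ import annotations

import json
import urllib.request


def build_rerank_request(base_url: str, model: str, query: str,
                         documents: list[str], top_n: int | None = None) -> urllib.request.Request:
    """Build a POST request for a Cohere-style rerank endpoint (hypothetical helper).

    base_url is assumed to end in /v1, like the base_url of the openai client.
    """
    payload = {"model": model, "query": query, "documents": documents}
    if top_n is not None:
        payload["top_n"] = top_n  # only return the top_n highest-scoring documents
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/rerank",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def rerank(req: urllib.request.Request) -> dict:
    """Send the request and decode the JSON response (requires a running server)."""
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Once the server is up, `rerank(build_rerank_request("http://<host>:<port>/v1", "<model-name>", query, docs, top_n=3))` returns the parsed response; take the host, port, and model name from your `config.yaml`.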


## 📘 Configuration docs


- **[GPT Server - DeepWiki docs (ask the AI directly how to use it)](https://deepwiki.com/shell-nlp/gpt_server "DeepWiki docs")**
<br>

- **[Detailed configuration guide](https://blog.csdn.net/q506610466/article/details/151360406 "Detailed configuration guide")**
<br>

- [Example configuration file](https://github.com/shell-nlp/gpt_server/blob/main/gpt_server/script/config_example.yaml "Configuration file")

## 🎉 Latest updates
<details open>
<summary><b>2025</b></summary>
 
```plaintext
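2025-11-30 Added support for the z-image text-to-image model
2025-11-16 Added support for the jinaai/jina-reranker-v3 model
2025-10-25 Added support for the qwen_image text-to-image model
2025-9-7   Added support for image editing models (example: gpt_server/tests/test_image_edit.py)
2025-8-8   Added initial vllm acceleration for embedding models
2025-6-17  Added support for jina-reranker-m0, the world's first multimodal multilingual reranker
2025-6-12  Added support for the text-to-image model flux (example: gpt_server/tests/test_image_gen.py)
2025-6-6   Added support for the bge-vl series (example: gpt_server/tests/test_openai_embedding_vl.py)
2025-6-6   Added support for ritrieve_zh_v1
2025-4-29  Added support for Qwen3
2025-4-24  Added TTS with the Spark-TTS backend
2025-4-14  Added the SGLang backend and some VL models
2025-4-2   Added the OpenAI ASR endpoint /v1/audio/transcriptions
2025-4-1   Added support for the internvl2.5 model
2025-2-9   Added support for QVQ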
```
</details>

<details close>
<summary><b>2024</b></summary>
 
```plaintext
2024-12-22 Added TTS models (tts, /v1/audio/speech)
2024-12-21 Added text moderation models (text-moderation, /v1/moderations)
2024-12-14 Added support for phi-4
2024-12-7  Added the /v1/rerank endpoint
2024-12-1  Added support for QWQ-32B-Preview
2024-10-15 Added support for Qwen2-VL
2024-9-19  Added support for minicpmv models
2024-8-17  Added LoRA deployment for the vllm/hf backends
2024-8-14  Added support for the InternVL2 multimodal model series
2024-7-28  Added dynamic batching for embedding/reranker models (infinity backend, faster than onnx/tensorrt)
2024-7-19  Added the LMDeploy PyTorch backend for the multimodal model glm-4v-9b
2024-6-22  Added function call (tools) support for the Qwen and ChatGLM series
2024-6-12  Added support for qwen-2
2024-6-5   Added the Yinka and zpoint_large_embedding_zh embedding models
2024-6-5   Added support for the glm4-9b series (hf and vllm)
2024-4-27  Added the LMDeploy accelerated inference backend
2024-4-20  Added support for llama-3
2024-4-13  Added support for deepseek
2024-4-4   Added the embedding model acge_text_embedding
2024-3-9   Added reranker models (bge-reranker, bce-reranker-base_v1)
2024-3-3   Added support for internlm-1.0 and internlm-2.0
2024-3-2   Added support for qwen-1.5 0.5B, 1.8B, 4B, 7B, 14B, and 72B
2024-2-4   Added the vllm backend
2024-1-6   Added support for Yi-34B
```
</details>

<details close>
<summary><b>2023</b></summary>
 
```plaintext
2023-12-31 Added support for qwen-7b and qwen-14b
2023-12-30 Added all-embedding (in principle supports any text embedding model)
2023-12-24 Added support for chatglm3-6b
```
</details>

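## 🧭 Roadmap

* [X] HF backend
* [X] vLLM backend
* [X] LMDeploy backend
* [X] SGLang backend
* [X] Text-to-speech (TTS) models
* [X] Speech-to-text (ASR) models
* [X] Text moderation models
* [X] Function call (tools) support (already available for the Qwen and ChatGLM series; more will be added on demand)
* [X] Multimodal models
* [X] Dynamic batching for Embedding models (via the infinity backend)
* [X] Dynamic batching for Reranker models (via the infinity backend)
* [X] Visual launch UI (unstable and of limited use to developers; will be deprecated)
* [X] Text-to-image models
* [X] Image editing models
* [X] Responses API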



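## ⚙️ Quick start

### 1. Set up the Python environment

#### 1.1 Install with uv (recommended: uv is much faster and easier to use than pip, conda, or poetry, and is used by many major open-source projects)

```bash
# Install uv
pip install uv -U # or see https://docs.astral.sh/uv/getting-started/installation/#standalone-installer
# uv venv --seed # (optional) create a uv virtual environment with seed packages
uv sync
source .venv/bin/activate # activate the uv environment
```

#### 1.2 Install with conda (optional; will be deprecated)

```bash
# 1. Create the conda environment
conda create -n gpt_server python=3.11

# 2. Activate the conda environment
conda activate gpt_server

# 3. Install the repository (you must use install.sh, otherwise the dependency conflicts cannot be resolved)
bash install.sh
```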
```

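### 2. Edit the startup configuration file

#### 2.1 Copy the example configuration file:
**Every configuration option is documented in detail in [config_example.yaml](https://github.com/shell-nlp/gpt_server/blob/main/gpt_server/script/config_example.yaml "Configuration file")**

```bash
# Enter the script directory
cd gpt_server/script
# Copy the example configuration file
cp config_example.yaml config.yaml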
```
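
The same copy step can be scripted in Python, for example inside a setup script. This is a minimal sketch with a hypothetical helper name, `ensure_config`: it copies `config_example.yaml` to `config.yaml` only when `config.yaml` does not exist yet, so re-running it never overwrites your edits.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def ensure_config(script_dir: str | Path) -> Path:
    """Copy config_example.yaml to config.yaml unless config.yaml already exists."""
    script_dir = Path(script_dir)
    example = script_dir / "config_example.yaml"
    target = script_dir / "config.yaml"
    if not target.exists():
        # Raises FileNotFoundError if config_example.yaml is missing.
        shutil.copyfile(example, target)
    return target
```

Run from the repository root, `ensure_config("gpt_server/script")` matches the `cd gpt_server/script` step above.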

### 3. Start the service
#### 3.1 Start from the command line

```bash
uv run gpt_server/serving/main.py
```
or
```bash
sh gpt_server/script/start.sh
```
or
```bash
python gpt_server/serving/main.py
```

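#### 3.2 Start with Docker

##### 3.2.0 Pull the image from Docker Hub
```bash
docker pull 506610466/gpt_server:latest # if the pull fails, try the mirror below
# If Docker Hub is unreachable from mainland China, try pulling through a domestic mirror (mirror availability is not guaranteed)
docker pull docker.xuanyuan.me/506610466/gpt_server:latest
```
##### 3.2.1 Start directly with docker run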
```bash
docker run -d \
  --name gpt_server \
  --restart always \
  --shm-size 32g \
  --network host \
  -v your_model_path/:your_model_path/ \
  -v your_config_path/config.yaml:/gpt_server/gpt_server/script/config.yaml \
  --gpus all \
  506610466/gpt_server:latest \
  python gpt_server/serving/main.py  
```

Replace `your_model_path` with your model path; it must match the path configured in `config.yaml`.
Replace `your_config_path` with the directory that contains your `config.yaml` file.
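
If you launch the container from a Python script, the `docker run` command above can be assembled as an argument list. A minimal sketch, assuming the flags shown above; `docker_run_args` is a hypothetical helper. It mounts the model directory at the same path inside the container, which is what keeps the model paths in `config.yaml` valid on both sides.

```python
import shlex


def docker_run_args(model_path: str, config_file: str,
                    image: str = "506610466/gpt_server:latest") -> list[str]:
    """Return the argv for the docker run command above (hypothetical helper)."""
    return [
        "docker", "run", "-d",
        "--name", "gpt_server",
        "--restart", "always",
        "--shm-size", "32g",
        "--network", "host",
        # Same path on the host and in the container, so config.yaml paths resolve in both.
        "-v", f"{model_path}:{model_path}",
        "-v", f"{config_file}:/gpt_server/gpt_server/script/config.yaml",
        "--gpus", "all",
        image,
        "python", "gpt_server/serving/main.py",
    ]


# shlex.join renders the list as a copy-pasteable shell command.
print(shlex.join(docker_run_args("/data/models", "/data/config.yaml")))
```

Passing the list to `subprocess.run(args, check=True)` avoids shell quoting problems when paths contain spaces.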


##### 3.2.2 Build the image manually and start it with Docker Compose (optional)

```bash
docker-compose  -f "docker-compose.yml" up -d --build gpt_server
```

<details close>
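<summary> <b> 3.3 Start the service with the visual UI (buggy and deprecated; contributions to improve the code are welcome)</b></summary>

#### 3.3 Start the service with the visual UI (optional; buggy and not recommended; contributions to improve the code are welcome)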

```bash
cd gpt_server/serving
streamlit run server_ui.py
```

##### 3.3.1 Server UI:

![server_ui_demo.png](assets/server_ui_demo.png)

</details>

### 4. Call the service with the openai library

**See the sample test code in the tests directory: https://github.com/shell-nlp/gpt_server/tree/main/tests**
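
With the `openai` library, a chat call is `OpenAI(base_url=..., api_key=...).chat.completions.create(model=..., messages=...)`. To show what goes over the wire, here is a standard-library sketch of the equivalent `POST /v1/chat/completions` request. The helper names `build_chat_request` and `chat`, the base URL, the model name, and the `"EMPTY"` API key are placeholders: use the host, port, and model names from your `config.yaml`.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str,
                       api_key: str = "EMPTY", stream: bool = False) -> urllib.request.Request:
    """Build an OpenAI-style chat completions request (hypothetical helper)."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/chat/completions",  # base_url is expected to end in /v1
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",  # placeholder key
        },
        method="POST",
    )


def chat(req: urllib.request.Request) -> str:
    """Send a non-streaming request and return the first choice's message text."""
    with urllib.request.urlopen(req) as resp:
        body = json.loads(resp.read().decode("utf-8"))
    return body["choices"][0]["message"]["content"]
```

`chat(build_chat_request("http://<host>:<port>/v1", "<model-name>", "Hello"))` needs a running server; streaming (`stream=True`) returns server-sent events and needs different response handling.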

### 5. Use the Chat UI

```bash
cd gpt_server/gpt_server/serving
streamlit run chat_ui.py
```

Chat UI:

![chat_ui_demo.png](assets/chat_ui_demo.png)



## ⚡ Supported models and inference backends

**Inference speed:** LMDeploy TurboMind > SGLang > vllm > LMDeploy PyTorch > HF

### Models officially supported by each inference backend


[LMDeploy](https://lmdeploy.readthedocs.io/en/latest/supported_models/supported_models.html) 

[vLLM](https://docs.vllm.ai/en/latest/models/supported_models.html) 

[SGLang](https://docs.sglang.ai/supported_models/generative_models.html) 

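#### Notes:
- **You can now set `model_type: auto` in `config.yaml`** to support every LLM and multimodal LLM that the installed version of vllm/sglang/lmdeploy supports.

- The compatibility tables below will be removed or restructured in the future. Models not listed may also work; refer to each backend's official documentation for the actual status.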

### **LLM**

|   Models / BackEnd    | model_type |  HF   | vllm  | LMDeploy TurboMind | LMDeploy PyTorch | SGLang |
| :-------------------: | :--------: | :---: | :---: | :----------------: | :--------------: | :----: |
|      chatglm4-9b      |  chatglm   |   √   |   √   |         √          |        √         |   √    |
|      chatglm3-6b      |  chatglm   |   √   |   √   |         ×          |        √         |   √    |
|   Qwen-1.0--3.0       |    qwen    |   √   |   √   |         √          |        √         |   √    |
|        Yi-34B         |     yi     |   √   |   √   |         √          |        √         |   √    |
|    Internlm-1.0--2.0  |  internlm  |   √   |   √   |         √          |        √         |   √    |
|       Deepseek        |  deepseek  |   √   |   √   |         √          |        √         |   √    |
|        Llama-3        |   llama    |   √   |   √   |         √          |        √         |   √    |
|      Baichuan-2       |  baichuan  |   √   |   √   |         √          |        √         |   √    |
|        QWQ-32B        |    qwen    |   √   |   √   |         √          |        √         |   √    |
|         Phi-4         |    phi     |   √   |   √   |         ×          |        ×         |   √    |

### **VLM** (multimodal leaderboard: https://rank.opencompass.org.cn/leaderboard-multimodal)

| Models / BackEnd | model_type |  HF   | vllm  | LMDeploy TurboMind | LMDeploy PyTorch | SGLang |
| :--------------: | :--------: | :---: | :---: | :----------------: | :--------------: | :----: |
|    glm-4v-9b     |  chatglm   |   ×   |   ×   |         ×          |        √         |   ×    |
|    InternVL2     |  internvl  |   ×   |   ×   |         √          |        √         |   ×    |
|InternVL2.5--3.5  |  internvl  |   ×   |   ×   |         √          |        √         |   ×    |
|  MiniCPM-V-2.6   |  minicpmv  |   ×   |   √   |         √          |        ×         |   ×    |
|  MiniCPM-V-4.5   |  minicpmv  |   ×   |   √   |         ×          |        ×         |   ×    |
|     Qwen-VL 2.0--3.0     |    qwen    |   ×   |   √   |         √         |        √         |   √    |
|       QVQ        |    qwen    |   ×   |   √   |         √          |        √         |   √    |

<br>

### Embedding/Rerank/Classify models

**In principle, all Embedding/Rerank/Classify models are supported**

**Inference speed:** infinity > sentence_transformers

The following models have been tested and are known to work:

| Models / BackEnd | sentence_transformers | infinity | vllm |
| ---------------- | --------------------- | -------- | ---- |
| bge-m3 | √ | √ | √ |
| bge-embedding | √ | √ | √ |
| bce-embedding | √ | √ | √ |
| puff | √ | √ | √ |
| piccolo-base-zh-embedding | √ | √ | √ |
| acge_text_embedding | √ | √ | √ |
| Yinka | √ | √ | √ |
| zpoint_large_embedding_zh | √ | √ | √ |
| xiaobu-embedding | √ | √ | √ |
| Conan-embedding-v1 | √ | √ | √ |
| qwen3-embedding | √ | √ | √ |
| ritrieve_zh_v1 | √ | √ | √ |
| jina-embeddings-v3 | √ | √ | √ |
| KoalaAI/Text-Moderation (text moderation, multi-class: detects violence, pornography, etc. in text) | × | √ | × |
| protectai/deberta-v3-base-prompt-injection-v2 (prompt injection, binary: detects prompt-injection text) | × | √ | × |
| bge-vl | √ | × | × |
| jina-reranker-m0 | √ | × | × |
| bge-reranker | √ | √ | × |
| bce-reranker | √ | √ | × |
| jina-reranker-v3 | √ | × | × |
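The table, together with the speed ranking stated above (infinity > sentence_transformers), implies a simple selection rule: use the fastest ranked backend that is marked √ for a given model. Below is a minimal sketch of that rule, assuming you copy the relevant rows of the table into a dictionary; `SUPPORT`, `SPEED_ORDER` and `pick_backend` are hypothetical names, not part of gpt_server.

```python
# Hypothetical helper (not part of gpt_server): a subset of the support
# table above, keyed by model name, listing the backends marked with √.
SUPPORT: dict[str, set[str]] = {
    "bge-m3": {"sentence_transformers", "infinity", "vllm"},
    "bge-reranker": {"sentence_transformers", "infinity"},
    "bge-vl": {"sentence_transformers"},
    "jina-reranker-m0": {"sentence_transformers"},
    "KoalaAI/Text-Moderation": {"infinity"},
}

# Fastest first, following the README: infinity > sentence_transformers.
SPEED_ORDER = ["infinity", "sentence_transformers"]


def pick_backend(model: str) -> str:
    """Return the fastest ranked backend that the table marks as supported."""
    supported = SUPPORT.get(model)
    if supported is None:
        raise KeyError(f"{model!r} is not in the tested-model table")
    for backend in SPEED_ORDER:
        if backend in supported:
            return backend
    raise ValueError(f"no ranked backend supports {model!r}")
```

`vllm` is deliberately left out of `SPEED_ORDER` because this README does not rank it against the other two backends.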

**ritrieve_zh_v1** currently ranks first on the C-MTEB leaderboard (MTEB: https://huggingface.co/spaces/mteb/leaderboard).

<br>

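### **ASR** (supports FunASR non-streaming models: https://github.com/modelscope/FunASR/blob/main/README_zh.md)
Only SenseVoiceSmall (the best-performing model) has been tested so far. Support for the other models below is copied from the official FunASR documentation and may not work; testing and issue reports are welcome.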

|    Models / BackEnd    | model_type |
| :--------------------: | :--------: |
|    SenseVoiceSmall     |   funasr   |
|     paraformer-zh      |   funasr   |
|     paraformer-en      |   funasr   |
|      conformer-en      |   funasr   |
|    Whisper-large-v3    |   funasr   |
| Whisper-large-v3-turbo |   funasr   |
|       Qwen-Audio       |   funasr   |
|    Qwen-Audio-Chat     |   funasr   |
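A client-side sketch for sending audio to an ASR model. It assumes the server exposes an OpenAI-style `/v1/audio/transcriptions` endpoint on port 8082 (the port mapped in `docker-compose.yml`) and that `model` equals the name you configured for the FunASR worker. `build_transcription_request` and `transcribe` are hypothetical helpers; the multipart body is built by hand so only the standard library is needed.

```python
import json
import urllib.request
import uuid


def build_transcription_request(
    audio: bytes,
    filename: str,
    model: str = "SenseVoiceSmall",  # placeholder: must match your configured name
    base_url: str = "http://localhost:8082/v1",  # assumed host, port and path prefix
) -> urllib.request.Request:
    """Build a multipart/form-data POST for an OpenAI-style transcription endpoint."""
    boundary = uuid.uuid4().hex
    parts = [
        # Text field carrying the model name.
        (
            f"--{boundary}\r\n"
            'Content-Disposition: form-data; name="model"\r\n\r\n'
            f"{model}\r\n"
        ).encode(),
        # File field carrying the raw audio bytes.
        (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
            "Content-Type: application/octet-stream\r\n\r\n"
        ).encode()
        + audio
        + b"\r\n",
        # Closing boundary.
        f"--{boundary}--\r\n".encode(),
    ]
    return urllib.request.Request(
        f"{base_url}/audio/transcriptions",
        data=b"".join(parts),
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
        method="POST",
    )


def transcribe(req: urllib.request.Request) -> str:
    """Send the request and return the "text" field of the JSON response."""
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["text"]
```

Building the request is kept separate from `transcribe` so it can be inspected without a running server; only `transcribe` performs the network round trip.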

<br>

### **TTS** Models

| Models / BackEnd | model_type |
| :--------------: | :--------: |
|    Spark-TTS     | spark_tts  |
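A similar sketch for TTS, assuming an OpenAI-style `/v1/audio/speech` endpoint on port 8082 that returns raw audio bytes. The `model` and `voice` defaults are placeholders: the model name must match your configuration, and which voices a Spark-TTS deployment accepts is not documented here. `build_speech_request` and `save_speech` are hypothetical helpers.

```python
import json
import urllib.request


def build_speech_request(
    text: str,
    model: str = "spark_tts",  # placeholder: must match your configured name
    voice: str = "default",  # placeholder: accepted voices are an assumption
    base_url: str = "http://localhost:8082/v1",  # assumed host, port and path prefix
) -> urllib.request.Request:
    """Build an OpenAI-style text-to-speech POST (assumed endpoint and fields)."""
    payload = {"model": model, "input": text, "voice": voice}
    return urllib.request.Request(
        f"{base_url}/audio/speech",
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def save_speech(req: urllib.request.Request, path: str) -> None:
    """Send the request and write the raw audio bytes of the response to `path`."""
    with urllib.request.urlopen(req) as resp, open(path, "wb") as f:
        f.write(resp.read())
```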


<br>

### **Text-to-Image** Models
[Flux model](https://huggingface.co/black-forest-labs/FLUX.1-dev)
<br>
[Z-Image-Turbo model](https://modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)
<br>
[Qwen-Image series models](https://modelscope.cn/models/Qwen/Qwen-Image-2512)


| Models / BackEnd | model_type |
| :--------------: | :--------: |
|       flux       |    flux    |
|    qwen_image    | qwen_image |
|     z_image      |  z_image   |
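For text-to-image, a sketch assuming an OpenAI-style `/v1/images/generations` endpoint on port 8082 that can return base64-encoded images (`response_format="b64_json"`); whether gpt_server honors that field is an assumption. `model` must match your configured name, and `build_image_request` / `decode_images` are hypothetical helpers.

```python
import base64
import json
import urllib.request


def build_image_request(
    prompt: str,
    model: str = "flux",  # placeholder: must match your configured name
    base_url: str = "http://localhost:8082/v1",  # assumed host, port and path prefix
) -> urllib.request.Request:
    """Build an OpenAI-style image-generation POST (assumed endpoint and fields)."""
    payload = {"model": model, "prompt": prompt, "response_format": "b64_json"}
    return urllib.request.Request(
        f"{base_url}/images/generations",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def decode_images(response_body: bytes) -> list[bytes]:
    """Decode every b64_json entry in an OpenAI-style images response body."""
    body = json.loads(response_body)
    return [base64.b64decode(item["b64_json"]) for item in body["data"]]
```

Sending the request with `urllib.request.urlopen` and passing `resp.read()` to `decode_images` yields the raw image bytes, which can be written straight to a `.png` file.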

<br>

### **Image Editing** Models
[Qwen-Image-Edit model](https://huggingface.co/Qwen/Qwen-Image-Edit)
<br>
[Qwen-Image-Edit-2511 model](https://modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)

| Models / BackEnd | model_type |
| :--------------: | :--------: |
| Qwen-Image-Edit  | qwen_image_edit |

<br>

## 🏗️ Architecture

![gpt_server_archs.png](assets/gpt_server_archs.png)

## 🤝 Acknowledgements
- [FastChat](https://github.com/lm-sys/FastChat)
- [vLLM](https://github.com/vllm-project/vllm)
- [LMDeploy](https://github.com/InternLM/lmdeploy)
- [SGLang](https://github.com/sgl-project/sglang)
- [infinity](https://github.com/michaelfeil/infinity)
- [FlashTTS](https://github.com/HuiResearch/FlashTTS)

## 📲 Contact Me (I will invite you to the discussion group)

![wechat.png](assets/wechat.png)

## 🌟 Star History

[![Star History Chart](https://api.star-history.com/svg?repos=shell-nlp/gpt_server&type=Date)](https://star-history.com/#shell-nlp/gpt_server&Date)

[open-issues-url]: https://github.com/shell-nlp/gpt_server/issues
[open-issues-shield]: https://img.shields.io/github/issues-raw/shell-nlp/gpt_server
[closed-issues-shield]: https://img.shields.io/github/issues-closed-raw/shell-nlp/gpt_server
[closed-issues-url]: https://github.com/shell-nlp/gpt_server/issues

[forks-url]: https://github.com/shell-nlp/gpt_server/network/members
[forks-shield]: https://img.shields.io/github/forks/shell-nlp/gpt_server?color=9cf
[stars-url]: https://github.com/shell-nlp/gpt_server/stargazers
[stars-shield]: https://img.shields.io/github/stars/shell-nlp/gpt_server?color=yellow
[license-url]: https://github.com/shell-nlp/gpt_server/blob/main/LICENSE
[license-shield]: https://img.shields.io/github/license/shell-nlp/gpt_server
[docker-pulls]: https://img.shields.io/docker/pulls/506610466/gpt_server
[ci-shield]: https://github.com/shell-nlp/gpt_server/actions/workflows/docker-image.yml/badge.svg
[ci-url]: https://github.com/shell-nlp/gpt_server/actions/workflows/docker-image.yml


================================================
FILE: docker-compose-bash.yaml
================================================
# This container is meant for users who want to work with the project directly inside the container
version: '3.8'
services:
  gpt_server_bash:
    # ------ Build an image from the project's latest code ------
    # build:
    #   context: .
    #   dockerfile: Dockerfile.copy
    # image: gpt_server:bash
    image: docker.1ms.run/506610466/gpt_server:latest
    container_name: bash
    # ------ Build an image from the project's latest code ------
    # image: docker.1ms.run/506610466/gpt_server:latest # To use only the image published on Docker Hub, uncomment this line and comment out the build-from-project section above
    command: /bin/bash
    tty: true              # corresponds to -it (interactive mode)
    stdin_open: true       # keep stdin open
    network_mode: "host"   # --network=host
    volumes:
      - ./gpt_server:/gpt_server/gpt_server # mount the latest code into the container so it runs the latest code
      - /home/dev/model/:/home/dev/model/ # mount the model directory
    shm_size: "100gb"      # --shm-size 100gb
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [ gpu ]
    ulimits:               # --ulimit memlock=-1
      memlock:
        soft: -1
        hard: -1

================================================
FILE: docker-compose.yml
================================================
version: '3'
services:
  gpt_server:
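    # Build
    # Why build every time instead of using image: docker.1ms.run/506610466/gpt_server:latest directly?
    # Mounting the code via volumes starts faster, but it hurts the runtime stability of running containers: code edited on the host takes effect immediately in the container runtime.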
    build:
      context: .
      dockerfile: Dockerfile.copy
    # image: docker.1ms.run/506610466/gpt_server:latest
    image: gpt_server:latest_
    container_name: gpt_server
    shm_size: '32g' # set shared memory to 32 GB
    restart: always
    # network_mode: host
    ports:
      - 8082:8082
      - 21001:21001
    environment:
      - TZ=Asia/Shanghai  # set the China (Asia/Shanghai) time zone
    volumes:
      - ./gpt_server:/gpt_server/gpt_server # mount the latest code and config into the container so it runs the latest code
      - /home/dev/model/:/home/dev/model/ # mount the model directory
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              # device_ids: [ '0', '1', '2', '3' ]
              count: all
              # count: 2  # two alternatives: device_ids or count (use one, not both)
              capabilities: [ gpu ]
    command: python gpt_server/serving/main.py


================================================
FILE: gpt_server/__init__.py
================================================


================================================
FILE: gpt_server/cli.py
================================================
import subprocess
import os
import typer

app = typer.Typer()
root_dir = os.path.dirname(__file__)
root_dir = os.path.abspath(root_dir)
chat_ui_path = os.path.join(root_dir, "serving", "chat_ui.py")
server_ui_path = os.path.join(root_dir, "serving", "server_ui.py")


@app.command(help="Launch the GPT Server UI")
def ui(
    server: bool = typer.Option(False, help="Launch the server management UI"),
    chat: bool = typer.Option(False, help="Launch the chat UI"),
):
    # Pass arguments as a list (no shell) so paths containing spaces still work
    if server:
        subprocess.run(["streamlit", "run", server_ui_path])
    if chat:
        subprocess.run(["streamlit", "run", chat_ui_path])


def main():
    app()


if __name__ == "__main__":
    main()


================================================
FILE: gpt_server/database/models/process_manager.py
================================================
"""This code is not currently used."""

from typing import List, Dict, Optional, Any
from multiprocessing import Process
from sqlmodel import SQLModel, Field, create_engine, Session, select
from datetime import datetime
import json
from uuid import uuid4


# Database model
class ProcessRecord(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True, description="Primary key ID")
    pid: int | None = Field(default=None, description="Process ID")
    args: str = Field(default="", description="Process arguments")
    status: str = Field(
        default="created", description="Process status"
    )  # created, started, stopped
    created_at: datetime = Field(default_factory=datetime.now, description="Creation time")
    started_at: Optional[datetime] = Field(default=None, description="Start time")
    stopped_at: Optional[datetime] = Field(default=None, description="Stop time")


class ProcessManager:
    def __init__(self, write_db: bool = False, db_url: str = "sqlite:///processes.db"):
        """Process manager.

        Parameters
        ----------
        write_db : bool, optional
            Whether to write process information to the database, by default False
        db_url : str, optional
            Database connection URL, by default "sqlite:///processes.db"
        """
        self.processes: List[Dict[Process, dict]] | None = []
        self.write_db = write_db
        if self.write_db:
            self.engine = create_engine(db_url)
            # Create tables
            SQLModel.metadata.create_all(self.engine)

    def add_process(
        self,
        target,
        args=(),
    ):
        p = Process(target=target, args=args)
        # Keep the ID within 63 bits so it fits SQLite's signed 64-bit INTEGER
        process_id = uuid4().int & ((1 << 63) - 1)
        self.processes.append({p: {"args": args, "process_id": process_id}})
        if self.write_db:
            # Record the process in the database
            with Session(self.engine) as session:

                process_record = ProcessRecord(
                    id=process_id,
                    pid=None,
                    args=json.dumps(args, ensure_ascii=False),
                    status="created",
                )
                session.add(process_record)
                session.commit()
                session.refresh(process_record)

    def start_all(self):
        for process in self.processes:
            for _process, process_info in process.items():
                _process.start()
                process_info["pid"] = _process.pid
                if self.write_db:
                    process_id = process_info["process_id"]
                    # Update the database record
                    with Session(self.engine) as session:
                        # Look up the record by process_id (simplified; a more robust identifier may be needed in practice)
                        statement = select(ProcessRecord).where(
                            ProcessRecord.id == process_id
                        )
                        result = session.exec(statement)
                        process_record = result.first()
                        if process_record:
                            process_record.pid = _process.pid
                            process_record.status = "started"
                            process_record.started_at = datetime.now()
                            session.add(process_record)
                            session.commit()
                            session.refresh(process_record)

    def join_all(self):
        for process in self.processes:
            for _process, process_info in process.items():
                _process.join()
                if self.write_db:
                    process_id = process_info["process_id"]
                    # Mark the database record as stopped
                    with Session(self.engine) as session:
                        statement = select(ProcessRecord).where(
                            ProcessRecord.id == process_id
                        )
                        results = session.exec(statement)
                        record = results.first()
                        if record:
                            # Use the `stopped_at` column defined on ProcessRecord
                            record.status = "stopped"
                            record.stopped_at = datetime.now()
                            session.add(record)
                            session.commit()


================================================
FILE: gpt_server/model_backend/__init__.py
================================================


================================================
FILE: gpt_server/model_backend/base.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict


class ModelBackend(ABC):
    @abstractmethod
    def stream_chat(self, params: Dict[str, Any]):
        pass

    def shutdown(self):
        pass


================================================
FILE: gpt_server/model_backend/hf_backend.py
================================================
from typing import Any, Dict
import torch
import json
from peft import PeftModel
from transformers import TextIteratorStreamer, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList
from threading import Thread
from gpt_server.model_backend.base import ModelBackend
from gpt_server.model_backend.utils import (
    InvalidScoreLogitsProcessor,
    StoppingCriteriaList,
    StopAtSpecificTokenCriteria,
    XgrammarLogitsProcessor,
)
import asyncio
from loguru import logger
from gpt_server.settings import get_model_config

invalid_score_processor = InvalidScoreLogitsProcessor()


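class NoneContextManager:
    """No-op context manager used when no LoRA adapter needs to be disabled."""

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return False so exceptions raised inside the `with` block propagate
        return False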


class HFBackend(ModelBackend):
    def __init__(self, tokenizer: PreTrainedTokenizer, model: torch.nn.Module) -> None:
        model_config = get_model_config()
        self.model = model
        self.tokenizer = tokenizer
        self.xgrammar_processor = XgrammarLogitsProcessor(tokenizer)
        self.lora_requests = []
        lora = model_config.lora
        if lora:
            lora_dict: dict = json.loads(lora)
            for i, (lora_name, lora_path) in enumerate(lora_dict.items()):
                self.lora_requests.append(
                    dict(
                        lora_name=lora_name,
                        lora_int_id=i,
                        lora_local_path=lora_path,
                    )
                )
                if i == 0:
                    self.model = PeftModel.from_pretrained(
                        model=model,
                        model_id=lora_path,
                        adapter_name=lora_name,
                    )
                    continue
                self.model.load_adapter(model_id=lora_path, adapter_name=lora_name)

    def shutdown(self):
        logger.info("HF backend shutting down")

    async def stream_chat(self, params: Dict[str, Any]):
        # params no longer needs to include a prompt
        messages = params["messages"]
        chat_template = params.get("chat_template", None)
        tools = params.get("tools", None)
        enable_thinking = bool(params.get("enable_thinking", True))
        prompt = self.tokenizer.apply_chat_template(
            messages,
            chat_template=chat_template,
            tokenize=False,
            add_generation_prompt=True,
            tools=tools,
            enable_thinking=enable_thinking,
        )
        logger.info(f"prompt:\n{prompt}")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))
        max_new_tokens = int(params.get("max_new_tokens", 512))
        # top_k = params.get("top_k", -1.0)
        # TODO ValueError: The following `model_kwargs` are not used by the model: ['presence_penalty', 'frequency_penalty'] (note: typos in the generate arguments will also show up in this list)
        # presence_penalty = float(params.get("presence_penalty", 0.0))
        # frequency_penalty = float(params.get("frequency_penalty", 0.0))
        stop = params.get("stop", [])  # stop strings
        input_ids = params.get("input_ids", None)
        if input_ids is None:
            input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        stop_words_ids = params.get("stop_words_ids", [])
        if temperature <= 1e-5:
            top_p = 1.0
            temperature = 0.01

        stopping_criteria = StoppingCriteriaList()  # stopping criteria
        stop_specific_token_criteria = StopAtSpecificTokenCriteria(
            token_id_list=stop_words_ids
        )
        stopping_criteria.append(stop_specific_token_criteria)
        logits_processor = LogitsProcessorList([invalid_score_processor])
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,  # forwarded to tokenizer.decode via **decode_kwargs
        )
        # TODO
        # ---- Support response_format (upstream support for BPE tokenizers is still poor) ----
        response_format = params.get("response_format")
        if response_format is not None:
            if response_format["type"] == "json_object":
                xgrammar_processor = (
                    self.xgrammar_processor.get_json_grammar_processor()
                )
                logits_processor.append(xgrammar_processor)

            elif response_format["type"] == "json_schema":
                json_schema = response_format["json_schema"]
                assert json_schema is not None
                guided_json = json_schema["schema"]
                xgrammar_processor = self.xgrammar_processor.get_json_schema_processor(
                    schema=json.dumps(guided_json)
                )
                logits_processor.append(xgrammar_processor)
            elif response_format["type"] == "text":
                pass

        # ---- Support response_format (upstream support for BPE tokenizers is still poor) ----
        generation_kwargs = dict(
            input_ids=input_ids.to(self.model.device),
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            # top_k=top_k,
            # presence_penalty=presence_penalty,
            # frequency_penalty=frequency_penalty,
        )
        use_lora = False
        for lora in self.lora_requests:
            if params["model"] == lora["lora_name"]:
                self.model.set_adapter(lora["lora_name"])
                use_lora = True
                break
        context_manager = NoneContextManager()
        if not use_lora and self.lora_requests:
            context_manager = self.model.disable_adapter()
        with context_manager:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
        prompt_tokens = len(input_ids.tolist()[0])
        completion_tokens = 0
        stop_flag = False
        try:
            current_text = ""
            previous_text = ""
            previous_token_ids = []
            current_token_ids = []
            delta_token_ids = []
            for new_text in streamer:
                for stop_word in stop:
                    if stop_word in new_text:
                        idx = new_text.rfind(stop_word)
                        stop_flag = True
                        print(
                            "********** Matched stop word:",
                            stop_word,
                            "in",
                            new_text,
                            "**********",
                        )
                        new_text = new_text[:idx]
                        break
                current_text = current_text + new_text
                completion_tokens += 1
                usage = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
                ret = {
                    "text": new_text,
                    "error_code": 0,
                    "usage": usage,
                }
                yield ret
                if stop_flag:
                    break
                # Yield to the event loop briefly to avoid choppy streaming output
                await asyncio.sleep(0.02)
            logger.info(current_text)
        except asyncio.CancelledError as e:
            stop_specific_token_criteria.stop = True


================================================
FILE: gpt_server/model_backend/lmdeploy_backend.py
================================================
import os
import sys
from lmdeploy import (
    GenerationConfig,
    TurbomindEngineConfig,
    PytorchEngineConfig,
)
from lmdeploy.serve.core.async_engine import AsyncEngine
from transformers import PreTrainedTokenizer
from typing import Any, Dict, AsyncGenerator, List, Optional
from lmdeploy.archs import get_task
from gpt_server.model_handler.reasoning_parser import ReasoningParserManager
from loguru import logger
from gpt_server.model_backend.base import ModelBackend
from gpt_server.settings import get_model_config
from lmdeploy.logger import RequestLogger
from lmdeploy.utils import get_logger

if sys.platform == "linux":
    # Prevent lmdeploy's PyTorch backend from failing when the Python C headers are not found
    os.environ["C_INCLUDE_PATH"] = "/usr/include/python3.8:" + (
        os.environ.get("C_INCLUDE_PATH", "")
    )
    os.environ["CPLUS_INCLUDE_PATH"] = "/usr/include/python3.8:" + (
        os.environ.get("CPLUS_INCLUDE_PATH", "")
    )
backend_map = {
    "lmdeploy-pytorch": "pytorch",  # PyTorch backend
    "lmdeploy-turbomind": "turbomind",  # TurboMind backend
}
# ------- Logging control -------
log_level = os.getenv("log_level", "WARNING")


get_logger("lmdeploy").setLevel(log_level)  # defaults to WARNING
os.environ["TM_LOG_LEVEL"] = "WARNING"
# ------- Logging control -------


class CustomRequestLogger(RequestLogger):
    def log_prompt(self, session_id: int, prompt: str) -> None:
        if not isinstance(prompt, str):
            # Prompt may be a GPT4V message with base64 images;
            # logging might be impractical due to length
            return

    def log_inputs(
        self,
        session_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        gen_config: GenerationConfig,
        adapter_name: str,
    ) -> None:
        max_log_len = self.max_log_len
        if max_log_len is not None:
            if prompt is not None:
                prompt = prompt[:max_log_len]

            if prompt_token_ids is not None:
                prompt_token_ids = prompt_token_ids[:max_log_len]

        logger.info(
            f"session_id={session_id} adapter_name={adapter_name} gen_config={gen_config}"
        )
        logger.info(f"prompt:\n{prompt}")


class LMDeployBackend(ModelBackend):
    def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
        model_config = get_model_config()
        logger.info(f"model_config: {model_config}")
        backend = backend_map[model_config.backend]
        logger.info(f"backend: {backend}")
        if backend == "pytorch":
            backend_config = PytorchEngineConfig(
                tp=model_config.num_gpus,
                dtype=model_config.dtype,
                session_len=model_config.max_model_len,
                enable_prefix_caching=model_config.enable_prefix_caching,
                cache_max_entry_count=model_config.gpu_memory_utilization,
                quant_policy=model_config.kv_cache_quant_policy,
            )
        if backend == "turbomind":
            backend_config = TurbomindEngineConfig(
                tp=model_config.num_gpus,
                enable_prefix_caching=model_config.enable_prefix_caching,
                session_len=model_config.max_model_len,
                dtype=model_config.dtype,
                cache_max_entry_count=model_config.gpu_memory_utilization,
                quant_policy=model_config.kv_cache_quant_policy,  # defaults to 0
            )
        pipeline_type, pipeline_class = get_task(model_path)
        logger.info(f"model architecture: {pipeline_type}")
        self.async_engine: AsyncEngine = pipeline_class(
            model_path=model_path,
            backend=backend,
            backend_config=backend_config,
        )
        self.tokenizer = self.async_engine.tokenizer
        self.reasoning_parser_cache = {}
        # custom request logger
        self.async_engine.request_logger = CustomRequestLogger(max_log_len=None)

    def shutdown(self):
        logger.info("lmdeploy backend shutting down")

    async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
        # params no longer needs to include a prompt
        messages = params["messages"]
        request_id = params.get("request_id", "0")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))
        top_k = params.get("top_k", 50)
        max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_words_ids", None) or []
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        reasoning_parser_type = params.get("reasoning_parser", None)
        request = params.get("request", None)
        enable_thinking = bool(params.get("enable_thinking", True))
        tools = params.get("tools", None)
        chat_template = params.get("chat_template", None)
        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)
        # prompt_token_ids = input_ids.tolist()[0]
        # build the lmdeploy generation config
        top_p = max(top_p, 1e-5)
        gen_config = GenerationConfig(
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,  # known to be problematic
            top_k=50 if top_k == -1 else top_k,
            stop_words=list(stop),
            skip_special_tokens=True,
            response_format=params["response_format"],
        )

        results_generator = self.async_engine.generate(
            messages=messages,
            session_id=int(request_id),
            gen_config=gen_config,
            enable_thinking=enable_thinking,
            tools=tools,
            chat_template=chat_template,
        )
        usage = {}
        previous_text = ""
        current_text = ""
        previous_token_ids = []
        current_token_ids = []
        delta_token_ids = []
        async for request_output in results_generator:
            current_text = current_text + request_output.response

            usage = {
                "prompt_tokens": request_output.input_token_len,
                "completion_tokens": request_output.generate_token_len,
                "total_tokens": request_output.input_token_len
                + request_output.generate_token_len,
            }
            ret = {
                "text": request_output.response,
                "error_code": 0,
                "usage": usage,
                "finish_reason": request_output.finish_reason,
            }

            if reasoning_parser_type:
                reasoning_parser = None
                delta_token_ids = (
                    request_output.token_ids
                    if request_output.token_ids is not None
                    else []
                )
                current_token_ids = current_token_ids + delta_token_ids
                if reasoning_parser_type in self.reasoning_parser_cache:
                    reasoning_parser = self.reasoning_parser_cache.get(
                        reasoning_parser_type
                    )
                else:
                    reasoning_parser = ReasoningParserManager.get(
                        reasoning_parser_type
                    )(self.tokenizer)
                    self.reasoning_parser_cache[reasoning_parser_type] = (
                        reasoning_parser
                    )
                reasoning_delta = reasoning_parser.extract_reasoning_content_streaming(
                    previous_text=previous_text,
                    current_text=current_text,
                    delta_text=request_output.response,
                    previous_token_ids=previous_token_ids,
                    current_token_ids=current_token_ids,
                    delta_token_ids=delta_token_ids,
                )
                if reasoning_delta is not None:
                    ret["text"] = (
                        reasoning_delta.content if reasoning_delta.content else ""
                    )
                    ret["reasoning_content"] = (
                        reasoning_delta.reasoning_content
                        if reasoning_delta.reasoning_content
                        else ""
                    )
                previous_token_ids = current_token_ids

            if not ret["text"] and not ret.get("reasoning_content", ""):
                continue
            yield ret
            previous_text = current_text
        logger.info(current_text)
        logger.info(usage)


================================================
FILE: gpt_server/model_backend/sglang_backend.py
================================================
import asyncio
import json
from typing import Any, AsyncGenerator, Dict

from loguru import logger
from sglang.srt.entrypoints.engine import (
    _launch_subprocesses,
    init_tokenizer_manager,
    run_detokenizer_process,
    run_scheduler_process,
)
from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ErrorResponse,
    MessageProcessingResult,
    ResponsesRequest,
    StreamOptions,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.entrypoints.openai.serving_responses import OpenAIServingResponses
from sglang.srt.server_args import ServerArgs
from starlette.responses import StreamingResponse
from transformers import PreTrainedTokenizer

from gpt_server.model_backend.base import ModelBackend
from gpt_server.settings import get_model_config


class CustomOpenAIServingResponses(OpenAIServingResponses):
    def _process_messages(self, request, is_multimodal):
        value: MessageProcessingResult = super()._process_messages(
            request, is_multimodal
        )
        prompt = value.prompt
        logger.info("prompt:\n" + prompt)
        return value


class CustomOpenAIServingChat(OpenAIServingChat):
    def _process_messages(self, request, is_multimodal):
        value: MessageProcessingResult = super()._process_messages(
            request, is_multimodal
        )
        prompt = value.prompt
        logger.info("prompt:\n" + prompt)
        return value


class SGLangBackend(ModelBackend):
    def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
        model_config = get_model_config()
        self.lora_requests = []
        self.model_path = model_path
        # ---
        kwargs = {
            "model_path": model_path,
            "trust_remote_code": True,
            "mem_fraction_static": model_config.gpu_memory_utilization,
            "tp_size": model_config.num_gpus,
            "dtype": model_config.dtype,
            "context_length": model_config.max_model_len,
            "grammar_backend": "xgrammar",
            "disable_radix_cache": not model_config.enable_prefix_caching,
            # https://docs.sglang.io/advanced_features/separate_reasoning.html
            "reasoning_parser": model_config.reasoning_parser,
            "tool_call_parser": model_config.tool_call_parser,
            "speculative_algorithm": model_config.speculative_algorithm,
            "speculative_num_steps": model_config.speculative_num_steps,
            "speculative_eagle_topk": 1 if model_config.speculative_algorithm else None,
            "disable_cuda_graph": model_config.enforce_eager,
        }
        server_args = ServerArgs(**kwargs)

        tokenizer_manager, template_manager, scheduler_infos, port_args = (
            _launch_subprocesses(
                server_args=server_args,
                init_tokenizer_manager_func=init_tokenizer_manager,
                run_scheduler_process_func=run_scheduler_process,
                run_detokenizer_process_func=run_detokenizer_process,
            )
        )
        self.tokenizer_manager = tokenizer_manager
        self.serving_chat = CustomOpenAIServingChat(
            tokenizer_manager=tokenizer_manager, template_manager=template_manager
        )
        # ---
        self.serving_responses = CustomOpenAIServingResponses(
            tokenizer_manager=tokenizer_manager, template_manager=template_manager
        )

    def shutdown(self):
        logger.info("sglang backend shut down")

    async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
        api_type = params.get("api_type", "chat")
        # Read request_id before the try block so the CancelledError handler can
        # always reference it, including on the Responses API path.
        request_id = params.get("request_id", "0")
        try:
            if api_type == "chat":
                # params no longer needs to carry a pre-rendered prompt
                messages = params.get("messages", [])
                tools = params.get("tools", None)
                enable_thinking = bool(params.get("enable_thinking", True))
                temperature = float(params.get("temperature", 0.8))
                top_p = float(params.get("top_p", 0.8))
                top_k = params.get("top_k", -1)
                max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
                stop_str = params.get("stop", None)
                stop_token_ids = params.get("stop_words_ids", None) or []
                presence_penalty = float(params.get("presence_penalty", 0.0))
                frequency_penalty = float(params.get("frequency_penalty", 0.0))
                # ---- response_format support ----
                response_format = params.get("response_format", None)
                # ------
                # Handle stop_str
                stop = set()
                if isinstance(stop_str, str) and stop_str != "":
                    stop.add(stop_str)
                elif isinstance(stop_str, list) and stop_str != []:
                    stop.update(stop_str)
                if tools:
                    for t in tools:
                        if t["function"].get("strict", None) is None:
                            t["function"]["strict"] = False
                request = ChatCompletionRequest(
                    messages=messages,
                    model=self.model_path,
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    seed=33,
                    stream=True,
                    stream_options=StreamOptions(
                        include_usage=True, continuous_usage_stats=True
                    ),
                    tools=tools,
                    response_format=response_format,
                    stop_token_ids=stop_token_ids,
                    stop=stop,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    top_k=top_k,
                    top_p=top_p if top_p != 0 else 0.01,
                    rid=request_id,
                    # tool_choice=params.get("tool_choice", "auto"),
                    chat_template_kwargs={"enable_thinking": enable_thinking},
                )

                response = await self.serving_chat.handle_request(
                    request=request, raw_request=None
                )

                if isinstance(response, StreamingResponse):
                    output_text = ""
                    reasoning_content_text = ""
                    usage = None
                    pre_usage = None
                    async for chunk in response.body_iterator:
                        # data: {"id":"chatcmpl-bf6de7d56c9bfecc","object":"chat.completion.chunk","created":1769947499,"model":"qwem3vl","choices":[{"index":0,"delta":{"content":"你好","reasoning_content":null},"logprobs":null,"finish_reason":null,"token_ids":null}],"usage":{"prompt_tokens":10,"total_tokens":11,"completion_tokens":1}}
                        # data: [DONE]
                        # Remove the literal SSE "data:" prefix; str.strip("data: ")
                        # strips a set of characters rather than the prefix.
                        chunk = chunk.strip()
                        if chunk.startswith("data:"):
                            chunk = chunk[len("data:") :].strip()
                        if not chunk:
                            continue
                        if chunk == "[DONE]":
                            break
                        chunk_dict = json.loads(chunk)
                        choices = chunk_dict["choices"]
                        if not choices:
                            continue
                        usage = chunk_dict.get("usage")
                        if usage is None and pre_usage is not None:
                            usage = pre_usage
                        pre_usage = usage
                        # Reset per chunk so a malformed chunk never reuses stale values.
                        text = ""
                        reasoning_content = None
                        tool_calls = None
                        try:
                            delta = choices[0]["delta"]
                            reasoning_content = delta.get("reasoning_content", None)
                            # extract tool_calls
                            tool_calls = delta.get("tool_calls", None)
                            # content may be missing or null (e.g. reasoning-only deltas)
                            text = delta.get("content") or ""
                        except Exception:
                            logger.error(f"Error in processing chunk: {chunk_dict}")
                        output_text += text
                        if reasoning_content:
                            reasoning_content_text += reasoning_content
                        ret = {
                            "text": text,
                            "usage": usage,
                            "error_code": 0,
                            "finish_reason": choices[0]["finish_reason"],
                            "reasoning_content": reasoning_content,
                            "tool_calls": tool_calls,
                        }
                        yield ret
                    logger.info(f"reasoning_content: \n{reasoning_content_text}")
                    logger.info(f"output_text: \n{output_text}")
                    logger.info(f"usage: {usage}")

                elif isinstance(response, ErrorResponse):
                    logger.error(f"sglang returned an error response: {response}")

            else:
                request_dict = params.get("responses_request", None)
                request = ResponsesRequest.model_validate(request_dict)
                request.model = self.model_path
                if request.stream:
                    response = await self.serving_responses.create_responses(
                        request, raw_request=None
                    )
                    async for chunk in response:
                        yield chunk
                else:
                    response = await self.serving_responses.create_responses(
                        request, raw_request=None
                    )
                    data = response.model_dump_json(exclude_unset=True)
                    yield data
        except asyncio.CancelledError:
            self.tokenizer_manager.abort_request(request_id)
            logger.warning(f"request_id : {request_id} aborted!")
            # Propagate the cancellation so the caller's task is actually cancelled.
            raise
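Both backends consume OpenAI-style server-sent-event lines of the form `data: {...}` and stop at `data: [DONE]`. The sketch below is a minimal, standalone illustration of that parsing step, assuming each chunk is one `str` SSE line; the helper names `parse_sse_data`, `accumulate_deltas` and the `DONE` sentinel are hypothetical and not part of gpt_server, SGLang, or vLLM. It removes the literal `data:` prefix (rather than calling `str.strip("data: ")`, which strips a character set) and concatenates `content` and `reasoning_content` deltas.

```python
import json
from typing import Any, Dict, Iterable

# Hypothetical sentinel returned for the terminating "data: [DONE]" line.
DONE = object()


def parse_sse_data(line: str) -> Any:
    """Return the decoded JSON payload of one SSE line, DONE, or None for blank lines."""
    line = line.strip()
    # Remove the literal "data:" prefix; str.strip("data: ") would instead remove
    # any of the characters d, a, t, ':' and ' ' from both ends of the string.
    if line.startswith("data:"):
        line = line[len("data:"):].strip()
    if not line:
        return None
    if line == "[DONE]":
        return DONE
    return json.loads(line)


def accumulate_deltas(lines: Iterable[str]) -> Dict[str, Any]:
    """Concatenate content and reasoning deltas, keeping the latest usage seen."""
    text, reasoning, usage = "", "", None
    for raw in lines:
        chunk = parse_sse_data(raw)
        if chunk is DONE:
            break
        if chunk is None:
            continue
        # Record usage before the choices check so a usage-only final chunk counts.
        if chunk.get("usage") is not None:
            usage = chunk["usage"]
        choices = chunk.get("choices") or []
        if not choices:
            continue
        delta = choices[0].get("delta") or {}
        # Either field may be absent or null in a given delta.
        text += delta.get("content") or ""
        reasoning += delta.get("reasoning_content") or ""
    return {"text": text, "reasoning_content": reasoning, "usage": usage}
```

Reading `usage` before skipping chunks with empty `choices` is a deliberate choice in this sketch: with OpenAI's `include_usage` option the final usage report can arrive in a chunk that has no choices.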


================================================
FILE: gpt_server/model_backend/utils.py
================================================
from typing import List, Type, Union
from pydantic import BaseModel
from transformers.generation.logits_process import LogitsProcessor
from transformers import PreTrainedTokenizerBase
from transformers.generation.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
    STOPPING_CRITERIA_INPUTS_DOCSTRING,
    add_start_docstrings,
)
import xgrammar as xgr
import torch


class XgrammarLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
        self.grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
        # Set by get_json_grammar_processor() or get_json_schema_processor().
        self.xgr_logits_processor = None

    def get_json_grammar_processor(self):
        compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
        self.xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
        return self.xgr_logits_processor

    def get_json_schema_processor(self, schema: Union[str, Type[BaseModel]]):
        compiled_grammar = self.grammar_compiler.compile_json_schema(schema)
        self.xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
        return self.xgr_logits_processor

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if self.xgr_logits_processor is None:
            raise RuntimeError(
                "Call get_json_grammar_processor() or get_json_schema_processor() first"
            )
        return self.xgr_logits_processor(input_ids=input_ids, scores=scores)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
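        # If the logits contain NaN/Inf, discard them and force token id 5.
        # This mirrors the workaround in ChatGLM's modeling code; the id is model-specific.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4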
        return scores


class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    Stop generation as soon as the first of the specified tokens is generated.
    """

    def __init__(self, token_id_list: Union[List[int], None] = None):
        """
        :param token_id_list: list of token ids that should stop generation
        """
        self.token_id_list = token_id_list or []
        self.stop = False

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list
        # Keeping scores around costs extra memory, so check input_ids directly.
        # Only the last token of the first sequence in the batch is inspected.
        if self.stop:
            return True
        return input_ids[0][-1].item() in self.token_id_list
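`StopAtSpecificTokenCriteria` looks only at the last token of the first sequence in the batch. As a minimal sketch of the per-sequence variant, assuming plain Python lists stand in for the `input_ids` tensor, the hypothetical helper `last_token_hits_stop` below reports, for every sequence, whether its most recent token is one of the stop ids.

```python
from typing import List, Optional, Sequence


def last_token_hits_stop(
    batch_token_ids: Sequence[Sequence[int]],
    stop_token_ids: Optional[Sequence[int]] = None,
) -> List[bool]:
    """Return, for each sequence, whether its last token is a stop token."""
    # A set gives O(1) membership checks; None means "no stop tokens".
    stop_set = set(stop_token_ids or [])
    # An empty sequence cannot have produced a stop token yet.
    return [bool(seq) and seq[-1] in stop_set for seq in batch_token_ids]
```

With real tensors, the same per-sequence check could be expressed with `torch.isin(input_ids[:, -1], stop_ids)`, which yields one boolean per batch row.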


================================================
FILE: gpt_server/model_backend/vllm_backend.py
================================================
from dataclasses import asdict
import json
from typing import Any, AsyncGenerator, Dict

from loguru import logger
from transformers import PreTrainedTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config.structured_outputs import StructuredOutputsConfig
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.inputs.data import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import StructuredOutputsParams

from gpt_server.model_backend.base import ModelBackend
from gpt_server.settings import get_model_config


class CustomOpenAIServingResponses(OpenAIServingResponses):
    async def _preprocess_chat(self, *args, **kwargs):
        value: tuple[list[ConversationMessage], list[TokensPrompt]] = (
            await super()._preprocess_chat(*args, **kwargs)
        )
        prompts: TokensPrompt = value[1][0]
        prompt = prompts.get("prompt", None)
        if prompt:
            logger.info("prompt:\n" + prompt)
        return value


class CustomOpenAIServingChat(OpenAIServingChat):
    async def render_chat_request(self, request):
        value = await super().render_chat_request(request)
        try:
            prompt = value[1][0]["prompt"]
            logger.info("prompt:\n" + prompt)
        except Exception:
            logger.error("request:\n" + str(value))
        return value


class VllmBackend(ModelBackend):
    def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
        self.model_path = model_path
        model_config = get_model_config()
        logger.info(f"model_config: {model_config}")
        max_loras = 1
        enable_lora = False
        self.lora_requests = []
        if model_config.lora:
            enable_lora = True
            lora_dict: dict = json.loads(model_config.lora)
            max_loras = len(lora_dict)
            for i, (lora_name, lora_path) in enumerate(lora_dict.items()):
                self.lora_requests.append(
                    LoRARequest(
                        lora_name=lora_name,
                        lora_int_id=i,
                        lora_local_path=lora_path,
                    )
                )
        # from vllm.config.kv_transfer import KVTransferConfig

        self.engine_args = AsyncEngineArgs(
            model_path,
            tensor_parallel_size=model_config.num_gpus,
            trust_remote_code=True,
            gpu_memory_utilization=model_config.gpu_memory_utilization,
            enable_chunked_prefill=model_config.enable_chunked_prefill,
            enable_lora=enable_lora,
            max_loras=max_loras,
            enable_prefix_caching=model_config.enable_prefix_caching,
            dtype=model_config.dtype,
            max_model_len=model_config.max_model_len,
            # guided_decoding_backend="xgrammar",
            # KV transfer via LMCache
            # kv_transfer_config=KVTransferConfig(
            #     kv_connector="LMCacheConnectorV1", kv_role="kv_both"
            # ),
            prefix_caching_hash_algo="xxhash",
            structured_outputs_config=StructuredOutputsConfig(backend="xgrammar"),
            enforce_eager=model_config.enforce_eager,
        )
        self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)
        models = OpenAIServingModels(
            engine_client=self.engine,
            base_model_paths=[
                BaseModelPath(name=self.model_path, model_path=self.model_path)
            ],
            lora_modules=None,
        )
        self.serving_chat = CustomOpenAIServingChat(
            engine_client=self.engine,
            models=models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
            trust_request_chat_template=True,
            enable_auto_tools=True,
            tool_parser=model_config.tool_call_parser,
            # https://docs.vllm.ai/en/latest/features/reasoning_outputs/
            reasoning_parser=(
                model_config.reasoning_parser if model_config.reasoning_parser else ""
            ),
        )
        self.serving_responses = CustomOpenAIServingResponses(
            engine_client=self.engine,
            models=models,
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
            enable_auto_tools=True,
            tool_parser=None,
            # https://docs.vllm.ai/en/latest/features/reasoning_outputs/
            reasoning_parser=(
                model_config.reasoning_parser if model_config.reasoning_parser else ""
            ),
        )

    def shutdown(self):
        self.engine.shutdown()
        logger.info("vllm backend shut down")

    async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:

        api_type = params.get("api_type", "chat")
        if api_type == "chat":
            # params no longer needs to carry a pre-rendered prompt
            messages = params["messages"]
            request_id = params.get("request_id", "0")
            temperature = float(params.get("temperature", 0.8))
            top_p = float(params.get("top_p", 0.8))
            top_k = int(params.get("top_k", 0))
            max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
            stop_str = params.get("stop", None)
            stop_token_ids = params.get("stop_words_ids", None) or []
            presence_penalty = float(params.get("presence_penalty", 0.0))
            frequency_penalty = float(params.get("frequency_penalty", 0.0))
            repetition_penalty = float(params.get("repetition_penalty", 1.0))
            enable_thinking = bool(params.get("enable_thinking", True))
            request = params.get("request", None)
            tools = params.get("tools", None)
            chat_template = params.get("chat_template", None)
            # Handle stop_str
            stop = set()
            if isinstance(stop_str, str) and stop_str != "":
                stop.add(stop_str)
            elif isinstance(stop_str, list) and stop_str != []:
                stop.update(stop_str)

            # ----------------------------------------------------------------
            # make sampling params in vllm
            top_p = max(top_p, 1e-5)
            if temperature <= 1e-5:
                top_p = 1.0
                temperature = 0.01
            response_format = params.get("response_format", None)
            guided_json_object = None
            guided_decoding = None
            guided_json = None
            if response_format is not None:
                if response_format["type"] == "json_object":
                    guided_json_object = True
                if response_format["type"] == "json_schema":
                    json_schema = response_format["json_schema"]
                    assert json_schema is not None
                    guided_json = json_schema["schema"]
                guided_decoding = StructuredOutputsParams(
                    json=guided_json,
                    regex=None,
                    choice=None,
                    grammar=None,
                    json_object=guided_json_object,
                    whitespace_pattern=None,
                )
                if response_format["type"] == "text":
                    guided_decoding = None

            lora_request = None
            for lora in self.lora_requests:
                if params.get("model") == lora.lora_name:
                    lora_request = lora
                    break

            request = ChatCompletionRequest(
                model=self.model_path,
                messages=messages,
                seed=33,
                stream=True,
                stream_options=StreamOptions(
                    include_usage=True, continuous_usage_stats=True
                ),
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repetition_penalty=repetition_penalty,
                stop=stop,
                stop_token_ids=stop_token_ids,
                structured_outputs=asdict(guided_decoding) if guided_decoding else None,
                request_id=request_id,
                tools=tools,
                # tool_choice=params.get("tool_choice", None),
                chat_template_kwargs={"enable_thinking": enable_thinking},
            )
            response = await self.serving_chat.create_chat_completion(
                request=request,
                raw_request=None,
            )
            if not hasattr(response, "__aiter__"):
                # On failure create_chat_completion returns an error object instead
                # of a stream; surface it rather than failing on iteration.
                logger.error(f"vllm chat completion failed: {response}")
                raise RuntimeError(f"vllm chat completion failed: {response}")
            output_text = ""
            reasoning_content_text = ""
            usage = None
            async for chunk in response:
                # data: {"id":"chatcmpl-bf6de7d56c9bfecc","object":"chat.completion.chunk","created":1769947499,"model":"qwem3vl","choices":[{"index":0,"delta":{"content":"你好","reasoning_content":null},"logprobs":null,"finish_reason":null,"token_ids":null}],"usage":{"prompt_tokens":10,"total_tokens":11,"completion_tokens":1}}
                # data: [DONE]
                # Remove the literal SSE "data:" prefix; str.strip("data: ")
                # strips a set of characters rather than the prefix.
                chunk = chunk.strip()
                if chunk.startswith("data:"):
                    chunk = chunk[len("data:") :].strip()
                if not chunk:
                    continue
                if chunk == "[DONE]":
                    break
                chunk_dict = json.loads(chunk)
                choices = chunk_dict["choices"]
                if not choices:
                    continue
                usage = chunk_dict.get("usage")
                # Reset per chunk so a malformed chunk never reuses stale values.
                text = ""
                reasoning_content = None
                tool_calls = None
                try:
                    delta = choices[0]["delta"]
                    # content may be missing or null (e.g. reasoning-only deltas)
                    text = delta.get("content") or ""
                    reasoning_content = delta.get("reasoning_content", None)
                    tool_calls = delta.get("tool_calls", None)
                except Exception:
                    logger.error(f"Error in processing chunk: {chunk_dict}")
                output_text += text
                if reasoning_content:
                    reasoning_content_text += reasoning_content
                ret = {
                    "text": text,
                    "usage": usage,
                    "error_code": 0,
                    "finish_reason": choices[0]["finish_reason"],
                    "reasoning_content": reasoning_content,
                    "tool_calls": tool_calls,
                }
                yield ret

            # logger.info(f"Lora: {request_output.lora_request}")
            logger.info(f"reasoning_content: \n{reasoning_content_text}")
            logger.info(f"output_text: \n{output_text}")
            logger.info(f"usage: {usage}")
        else:
            request_dict = params.get("responses_request", None)
            request = ResponsesRequest.model_validate(request_dict)
            request.model = self.model_path
            if request.stream:
                response = await self.serving_responses.create_responses(request)
                async for chunk in response:
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
            else:
                response = await self.serving_responses.create_responses(request)
                data = response.model_dump_json(exclude_unset=True)
                yield data


if __name__ == "__main__":
    s = 'data: {"id":"chatcmpl-bf6de7d56c9bfecc","object":"chat.completion.chunk","created":1769947499,"model":"qwem3vl","choices":[{"index":0,"delta":{"content":"你好","reasoning_content":null},"logprobs":null,"finish_reason":null,"token_ids":null}],"usage":{"prompt_tokens":10,"total_tokens":11,"completion_tokens":1}}'
    # Remove the literal SSE "data:" prefix before parsing; json is already
    # imported at module level.
    v = s[len("data:") :].strip()
    print(json.loads(v))
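`VllmBackend.stream_chat` turns the OpenAI `response_format` field into xgrammar-backed structured-output settings: `json_object` allows any JSON object, `json_schema` constrains output to the supplied schema, and `text` disables constraints. The sketch below restates that mapping as a plain function so it can be checked without vLLM installed. The function name `response_format_to_structured_outputs` is hypothetical; the returned keys are assumed to mirror the `json` and `json_object` arguments passed to `StructuredOutputsParams` above, and unlike the backend code it rejects unknown types explicitly.

```python
from typing import Any, Dict, Optional


def response_format_to_structured_outputs(
    response_format: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Map an OpenAI-style response_format to structured-output kwargs, or None."""
    if not response_format:
        return None
    kind = response_format.get("type")
    if kind == "json_object":
        # Any syntactically valid JSON object is accepted.
        return {"json_object": True}
    if kind == "json_schema":
        json_schema = response_format.get("json_schema")
        if not json_schema or "schema" not in json_schema:
            raise ValueError("json_schema response_format requires a 'schema' entry")
        # Output is constrained to the supplied JSON Schema.
        return {"json": json_schema["schema"]}
    if kind == "text":
        # Plain text: no grammar constraint.
        return None
    raise ValueError(f"unsupported response_format type: {kind!r}")
```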


================================================
FILE: gpt_server/model_handler/__init__.py
================================================


================================================
FILE: gpt_server/model_handler/chat_template/get_chat_template.py
================================================
from pathlib import Path
from typing import Literal, Optional

cur_path = Path(__file__).parent


def get_chat_template(
    model_name: str = "", lang: Literal["en", "zh"] = "en"
) -> Optional[str]:
    """Get the chat_template for a model.

    Parameters
    ----------
    model_name : str
        Model name.
    lang : str, optional
        Template language, by default "en".

    Returns
    -------
    Optional[str]
        The chat_template text, or None if no bundled template matches model_name.
    """
    suffix = ""
    if lang == "zh":
        suffix = "_zh"
    if model_name in ["qwen3", "qwen2_5", "qwen"]:
        with open(cur_path / f"qwen3{suffix}.jinja", "r", encoding="utf8") as f:
            return f.read()
    return None


if __name__ == "__main__":

    chat_template = get_chat_template("qwen3", lang="zh")
    print(chat_template)
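`get_chat_template` maps the Qwen family names to the bundled `qwen3.jinja` (or `qwen3_zh.jinja` when `lang="zh"`) and returns nothing for any other name. As a small sketch of just that name-resolution step, without file I/O, the hypothetical helper `resolve_template_filename` below returns the filename that would be opened.

```python
from typing import Optional

# Model names that share the bundled Qwen3 template (as listed in get_chat_template).
QWEN_FAMILY = {"qwen3", "qwen2_5", "qwen"}


def resolve_template_filename(model_name: str = "", lang: str = "en") -> Optional[str]:
    """Return the bundled .jinja filename for a model, or None if there is none."""
    if model_name not in QWEN_FAMILY:
        return None
    # Only "zh" selects the Chinese variant; every other value falls back to English.
    suffix = "_zh" if lang == "zh" else ""
    return f"qwen3{suffix}.jinja"
```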


================================================
FILE: gpt_server/model_handler/chat_template/qwen3.jinja
================================================
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- else %}
        {{- 'You are a helpful assistant. \n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if message.content is string %}
        {%- set content = message.content %}
    {%- else %}
        {%- set content = '' %}
    {%- endif %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is string %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in content %}
                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
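Before rendering, the template scans `messages` backwards to find the last "real" user query: a user message whose content is a string and is not a `<tool_response>...</tool_response>` wrapper. Assistant turns after that index keep their `<think>` block; earlier ones are rendered without it. The sketch below restates that scan in Python, assuming messages are plain dicts; the function name `last_query_index` is borrowed from the template's namespace variable and is otherwise hypothetical.

```python
from typing import Any, Dict, List


def last_query_index(messages: List[Dict[str, Any]]) -> int:
    """Index of the last user message that is a genuine query, per the Qwen3 template."""
    for index in range(len(messages) - 1, -1, -1):
        message = messages[index]
        content = message.get("content")
        # Tool results fed back as user text do not count as a new query.
        is_tool_response = (
            isinstance(content, str)
            and content.startswith("<tool_response>")
            and content.endswith("</tool_response>")
        )
        if message.get("role") == "user" and isinstance(content, str) and not is_tool_response:
            return index
    # The template's default when no such message exists.
    return len(messages) - 1
```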

================================================
FILE: gpt_server/model_handler/chat_template/qwen3_zh.jinja
================================================
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- else %}
        {{- 'You are a helpful assistant. \n\n' }}
    {%- endif %}
    {{- "# Tools\n\n你每次只能调用一个function来协助处理用户查询。\n\n在<tools></tools> XML标签中提供了function的签名(即函数的结构信息):\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\n对于单个function的调用, 返回一个包含function name和参数的 JSON 对象,并用 <tool_call></tool_call> XML 标签包裹,形如:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if message.content is string %}
        {%- set content = message.content %}
    {%- else %}
        {%- set content = '' %}
    {%- endif %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is string %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in content %}
                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
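When an assistant message has no string `reasoning_content`, both Qwen3 templates recover it from the content itself by splitting on `</think>`: the reasoning is the text after `<think>` in the first part, and the answer is everything after the last `</think>`, with surrounding newlines trimmed. The hypothetical helper `split_think` below reproduces those two string expressions in Python as a sketch for checking edge cases.

```python
from typing import Tuple


def split_think(content: str) -> Tuple[str, str]:
    """Split assistant content into (reasoning_content, content) like the template does."""
    if "</think>" not in content:
        # No reasoning block: the template leaves content unchanged.
        return "", content
    # Text before the first </think>, after the last <think>, with newlines trimmed.
    reasoning = content.split("</think>")[0].rstrip("\n").split("<think>")[-1].lstrip("\n")
    # Text after the last </think>, with leading newlines trimmed.
    answer = content.split("</think>")[-1].lstrip("\n")
    return reasoning, answer
```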

================================================
FILE: gpt_server/model_handler/chat_template/qwen3vl.jinja
================================================
{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count) %}
    {%- if content is string %}
        {{- content }}
    {%- else %}
        {%- for item in content %}
            {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
                {%- if do_vision_count %}
                    {%- set image_count.value = image_count.value + 1 %}
                {%- endif %}
                {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
                <|vision_start|><|image_pad|><|vision_end|>
            {%- elif 'video' in item or item.type == 'video' %}
                {%- if do_vision_count %}
                    {%- set video_count.value = video_count.value + 1 %}
                {%- endif %}
                {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
                <|vision_start|><|video_pad|><|vision_end|>
            {%- elif 'text' in item %}
                {{- item.text }}
            {%- endif %}
        {%- endfor %}
    {%- endif %}
{%- endmacro %}
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- render_content(messages[0].content, false) + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + render_content(messages[0].content, false) + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" %}
        {%- set content = render_content(message.content, false) %}
        {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
            {%- set ns.multi_step_tool = false %}
            {%- set ns.last_query_index = index %}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- set content = render_content(message.content, True) %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is string %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in content %}
                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- else %}
        {{- '<think>\n' }}
    {%- endif %}
{%- endif %}

================================================
FILE: gpt_server/model_handler/pitch.py
================================================
from typing import Optional
from flashtts.llm.vllm_generator import VllmGenerator
import flashtts
from loguru import logger


class VllmGenerator_(VllmGenerator):
    def __init__(
        self,
        model_path: str,
        max_length: int = 32768,
        gpu_memory_utilization: float = 0.6,
        device: str = "cuda",
        stop_tokens: Optional[list[str]] = None,
        stop_token_ids: Optional[list[int]] = None,
        **kwargs,
    ):
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        engine_kwargs = dict(
            model=model_path,
            max_model_len=max_length,
            gpu_memory_utilization=gpu_memory_utilization,
            # device=device,
            disable_log_stats=True,
            # disable_log_requests=True,
            **kwargs,
        )
        async_args = AsyncEngineArgs(**engine_kwargs)

        self.model = AsyncLLMEngine.from_engine_args(async_args)

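        # Deliberately skip VllmGenerator.__init__ (the engine was already built above)
        # and run the next initializer in the MRO with the tokenizer/stop settings.
        super(VllmGenerator, self).__init__(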
            tokenizer=model_path,
            max_length=max_length,
            stop_tokens=stop_tokens,
            stop_token_ids=stop_token_ids,
        )


def pitch_flashtts():
    flashtts.llm.vllm_generator.VllmGenerator = VllmGenerator_
    logger.info("patch flashtts.llm.vllm_generator.VllmGenerator")
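

# Illustrative sketch (hypothetical classes, not part of flashtts): why VllmGenerator_
# calls super(VllmGenerator, self).__init__. super(Parent, self) starts the MRO lookup
# *after* Parent, so Parent.__init__ is skipped and the next class's __init__ runs.
# VllmGenerator_ relies on this to keep the engine it built itself (assumption: the
# skipped parent constructor would otherwise set up its own engine).
class _SketchBase:
    def __init__(self):
        self.calls = ["base"]


class _SketchParent(_SketchBase):
    def __init__(self):
        super().__init__()
        self.calls.append("parent")


class _SketchChild(_SketchParent):
    def __init__(self):
        # Skip _SketchParent.__init__ and run _SketchBase.__init__ directly
        super(_SketchParent, self).__init__()
        self.calls.append("child")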


================================================
FILE: gpt_server/model_handler/reasoning_parser.py
================================================
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers
import re
from typing import Optional, Sequence, Tuple, Union

from lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage

from lmdeploy.serve.openai.reasoning_parser import (
    ReasoningParser,
    ReasoningParserManager,
)


@ReasoningParserManager.register_module(name="deepseek-r1", force=True)
class DeepSeekR1ReasoningParser(ReasoningParser):
    """Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning
    content from the model output.
    """

    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_start_token_id is None or self.think_end_token_id is None:
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!"
            )

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        **kwargs,
    ) -> Union[DeltaMessage, None]:
        """Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming.

        Has to be an instance method because it requires state - the current tokens/diffs, but also the information
        about what has previously been parsed and extracted (see constructor)
        """
        if len(delta_token_ids) == 1:
            if delta_token_ids[0] == self.think_end_token_id:
                return DeltaMessage(content="")
            elif delta_token_ids[0] == self.think_start_token_id:
                return None

        # Check if <think> is present in previous or delta.
        # Keep compatibility with models that don't generate <think> tokens.
        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token) :]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[
                    start_index + len(self.think_start_token) : end_index
                ]
                content = delta_text[end_index + len(self.think_end_token) :]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta, also need to check for </think>.
            # Because the model may have generated </think> without <think>
            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
            if self.think_end_token_id in delta_token_ids:
                # </think> in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token) :]
                return DeltaMessage(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
            elif self.think_end_token_id in previous_token_ids:
                # </think> in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
                # no </think> in previous or delta, reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest, **kwargs
    ) -> Tuple[Optional[str], Optional[str]]:
        """Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Args:
            model_output (str): The model-generated string to extract reasoning content from.
            request (ChatCompletionRequest): The request object that was used to generate the model_output.

        Returns:
            reasoning_content (str | None): The reasoning content.
            final_output (str | None): The content.
        """
        # DeepSeek R1 doesn't generate <think> now.
        # Thus we assume the reasoning content is always at the start.
        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
        if self.think_end_token not in model_output:
            return model_output, None
        else:
            # Add a start token if it's missing to keep compatibility.
            if self.think_start_token not in model_output:
                model_output = f"{self.think_start_token}{model_output}"
            # Use a regex to find the reasoning content
            reasoning_content = self.reasoning_regex.findall(model_output)[0]

            end_index = len(
                f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
            )
            final_output = model_output[end_index:]

            if len(final_output) == 0:
                return reasoning_content, None

            return reasoning_content, final_output
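

# Illustrative sketch (hypothetical helper, not used by the parser above): a simplified,
# dependency-free version of the non-streaming split in extract_reasoning_content.
# Like the method, it treats the whole output as reasoning when "</think>" is absent and
# tolerates a missing "<think>" (DeepSeek R1 may omit it). Unlike the method, it uses
# str.partition instead of a regex, so results may differ for unusual inputs such as
# text that appears before "<think>".
def split_reasoning_sketch(model_output, start="<think>", end="</think>"):
    if end not in model_output:
        # No closing tag: everything generated so far is reasoning
        return model_output, None
    reasoning, _, content = model_output.partition(end)
    if start in reasoning:
        # Keep only the text after the first opening tag
        reasoning = reasoning.split(start, 1)[1]
    # Mirror the method: an empty remainder becomes None
    return reasoning, content or None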


================================================
FILE: gpt_server/model_handler/tool_parser.py
================================================
import json
import re
from typing import List, Literal, Optional

from loguru import logger
from pydantic import BaseModel, Field
import shortuuid
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
    FunctionCall,
)

from vllm.tool_parsers import ToolParser, ToolParserManager


class ToolCall(BaseModel):
    """Tool call response."""

    index: Optional[int] = None
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ExtractedToolCallInformation(BaseModel):
    # modified from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/protocol.py#L1199
    # indicate if tools were called
    tools_called: bool
    # extracted tool calls
    tool_calls: List[ToolCall]
    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


@ToolParserManager.register_module(["qwen2_5"])
class Qwen2d5ToolParser(ToolParser):
    def __init__(self, tokenizer: object):
        super().__init__(tokenizer)
        self.position = 0
        self.tool_start_token = "<tool_call>"
        self.tool_end_token = "</tool_call>"
        self.pattern = r"<tool_call>(.*?)</tool_call>"

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        text = model_output
        if self.tool_start_token in text and self.tool_end_token in text:
            logger.debug("tool_parse: tool_start_token and tool_end_token found in text")
            # get tool_call in text
            match_result_list = re.findall(self.pattern, text, re.DOTALL)
            tool_calls = []
            index = -1
            for match_result in match_result_list:
                index += 1
                action = json.loads(match_result)
                name = action["name"]
                try:
                    arguments = json.dumps(action["arguments"], ensure_ascii=False)
                except KeyError:
                    arguments = json.dumps(action["parameters"], ensure_ascii=False)
                tool_calls.append(
                    ToolCall(
                        index=index,
                        function=FunctionCall(name=name, arguments=arguments),
                    )
                )

            # get text outside of tags
            if not text.startswith("<tool_call>"):
                text = text[: text.find("<tool_call>")]
            elif not text.endswith("</tool_call>"):
                text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
            else:
                text = ""
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else "",
            )
        elif self.tool_start_token in text or self.tool_end_token in text:
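            # Only one of tool_start_token / tool_end_token is in text (e.g. the
            # model emitted </tool_call> without <tool_call>): fall back to JSON matching
            logger.debug("tool_parse: only one of the tool_call tags found in text")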
            pattern = r"\{[^{}]*\{[^{}]*\}[^{}]*\}|{[^{}]*}"
            match_result_list = re.findall(pattern, text, re.DOTALL)
            tool_calls = []
            tools_called = False
            index = -1
            # parameters
            for match_result in match_result_list:
                index += 1
                action = json.loads(match_result)
                name = action["name"]
                try:
                    arguments = json.dumps(action["arguments"], ensure_ascii=False)
                except KeyError:
                    arguments = json.dumps(action["parameters"], ensure_ascii=False)

                tool_calls.append(
                    ToolCall(
                        function=FunctionCall(name=name, arguments=arguments),
                        index=index,
                    )
                )
                tools_called = True

            return ExtractedToolCallInformation(
                tools_called=tools_called,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else "",
            )
        logger.debug("tool_parse: no tool call found")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=text
        )
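

# Illustrative sketch (hypothetical helper, not called by the server): the tagged branch of
# Qwen2d5ToolParser.extract_tool_calls reduced to plain dicts, without vLLM types.
# Each <tool_call>...</tool_call> body is parsed as JSON; "arguments" (or the fallback
# "parameters") is re-serialized to a JSON string, as OpenAI-style tool calls expect.
# Unlike the parser, a call with neither key gets "{}" instead of raising KeyError.
import json
import re


def extract_tagged_tool_calls_sketch(text):
    calls = []
    bodies = re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
    for index, body in enumerate(bodies):
        action = json.loads(body)
        args = action.get("arguments", action.get("parameters", {}))
        calls.append(
            {
                "index": index,
                "name": action["name"],
                "arguments": json.dumps(args, ensure_ascii=False),
            }
        )
    return calls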


def tool_parser(full_text: str, tool_parser_: ToolParser, tools, ret):
    try:
        request = ChatCompletionRequest(
            messages=[{"role": "user", "content": full_text}], tools=tools
        )
        tool_call_info = tool_parser_.extract_tool_calls(
            model_output=full_text, request=request
        )
        tools_called = tool_call_info.tools_called
        tool_calls = []
        for index, call in enumerate(tool_call_info.tool_calls):
            tool_call = call.model_dump()
            # model_dump() includes every declared field, so check for a missing value,
            # not a missing key
            if tool_call.get("index") is None:
                tool_call["index"] = index
            tool_calls.append(tool_call)

        # -----------------------------------
        ret["text"] = ""
        ret["tool_calls"] = tool_calls
        ret["finish_reason"] = (
            "tool_calls" if tools and tools_called else ret.get("finish_reason", "stop")
        )
        if tools:
            logger.info(
                f"Tool parsing {'succeeded' if tools_called else 'failed'}, tool_calls: {tool_calls}"
            )
        if not tools_called:
            return None
        return json.dumps(ret).encode() + b"\0"
    except Exception as e:
        logger.warning(f"Error in tool_parser: {e}")
        import traceback

        traceback.print_exc()
        return None


import logging
from typing import Any, Dict, List, Optional


class ToolCallStreamProcessor:
    """
    Accumulate streamed tool_calls deltas; receives only the tool_calls part of each chunk.
    """

    def __init__(self):
        # Accumulated data for every tool call, keyed by its index
        self.tool_calls: Dict[int, Dict[str, Any]] = {}

    def process_chunk(self, tool_calls_data: List[Dict]) -> Optional[List[Dict]]:
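        """
        Merge one chunk of tool_calls data into the accumulated state.
        Args: tool_calls_data - the tool_calls list extracted from a streamed delta.
        Returns: always None; call get_completed_tool_calls() to read the merged calls.
        """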
        if not tool_calls_data:
            return None

        for tool_call in tool_calls_data:
            index = tool_call.get("index", 0)

            # Initialize a new tool call
            if index not in self.tool_calls:
                self.tool_calls[index] = {
                    "id": None,
                    "type": "function",
                    "function": {"name": None, "arguments": ""},
                }

            current = self.tool_calls[index]

            # Update the ID (only present in the first chunk)
            if tool_call.get("id"):
                current["id"] = tool_call["id"]

            # Update the function name (only present in the first chunk);
            # treat an explicit "function": null as an empty dict
            function_data = tool_call.get("function") or {}
            if function_data.get("name"):
                current["function"]["name"] = function_data["name"]

            # Append this chunk's fragment to the arguments string
            if function_data.get("arguments"):
                current["function"]["arguments"] += function_data["arguments"]

        return None

    def get_completed_tool_calls(self) -> Optional[List[Dict]]:
        """
        Return all complete tool calls, with arguments kept as the accumulated JSON string.
        Usually called after receiving finish_reason='tool_calls'.
        """
        if not self.tool_calls:
            return None

        completed_calls = []

        for index in sorted(self.tool_calls.keys()):
            call_data = self.tool_calls[index]

            # Skip calls that never received an id or function name
            if not call_data["id"] or not call_data["function"]["name"]:
                logging.warning(f"Tool call {index} is incomplete, skipping")
                continue

            # Arguments stay as the raw accumulated JSON string (not parsed)
            args_str = call_data["function"]["arguments"]

            completed_calls.append(
                {
                    "id": call_data["id"],
                    "type": call_data["type"],
                    "function": {
                        "name": call_data["function"]["name"],
                        "arguments": args_str,
                    },
                }
            )

        return completed_calls if completed_calls else None

    def reset(self):
        """Reset the processor."""
        self.tool_calls = {}


if __name__ == "__main__":
    from transformers import AutoTokenizer

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "City and state, e.g., 'San Francisco, CA'",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    glm_full_text = """Action: get_weather
Action Input: {"location": "Nanjing", "unit": "celsius"}"""
    qwen_full_text = """<tool_call>{"name": "get_weather", "arguments": {"location": "Nanjing", "unit": "celsius"}}</tool_call>"""
    qwen3coder_text = """
<tool_call>
<function=get_weather>
<parameter=location>
南京
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>
"""
    tokenizer = AutoTokenizer.from_pretrained("/home/dev/model/Qwen/Qwen3___5-35B-A3B/")
    tool_parser_ = ToolParserManager.get_tool_parser("qwen2_5")(tokenizer)
    tool_parser(
        full_text=qwen_full_text, tool_parser_=tool_parser_, tools=tools, ret={}
    )


================================================
FILE: gpt_server/model_handler/utils.py
================================================


================================================
FILE: gpt_server/model_worker/__init__.py
================================================
from gpt_server.model_worker.utils import patch
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
patch()


def patch_infinity_embedder():
    import infinity_emb.transformer.embedder.sentence_transformer as embedder_module

    def patched_embedder_tokenize_lengths(self, sentences: list[str]) -> list[int]:
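        """Fixed SentenceTransformerPatched.tokenize_lengths implementation."""
        # Call the tokenizer directly (modern transformers API)
        tks = self._infinity_tokenizer(
            sentences,
            add_special_tokens=False,
            truncation="longest_first",
            padding=False,
            return_length=True,
            return_attention_mask=False,
            return_token_type_ids=False,
        )

        # Extract per-sentence lengths. BatchEncoding is a UserDict (not a dict), and
        # without return_tensors "length" is a plain list, so only convert tensors/arrays
        if "length" in tks:
            lengths = tks["length"]
            return lengths.tolist() if hasattr(lengths, "tolist") else list(lengths)
        elif getattr(tks, "encodings", None):
            # Fast tokenizers expose tokenizers.Encoding objects; slow ones return None here
            return [len(t.tokens) for t in tks.encodings]
        else:
            return [len(seq) for seq in tks["input_ids"]]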

    embedder_module.SentenceTransformerPatched.tokenize_lengths = (
        patched_embedder_tokenize_lengths
    )


def patch_infinity_crossencoder():
    import infinity_emb.transformer.crossencoder.torch as crossencoder_module

    def patched_tokenize_lengths(self, sentences: list[str]) -> list[int]:
        """Fixed tokenize_lengths implementation using the modern transformers API."""
        # Call the tokenizer directly via __call__
        tks = self._infinity_tokenizer(
            sentences,
            add_special_tokens=False,
            truncation="longest_first",
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            return_length=True,
            return_tensors=None,
        )
        # BatchEncoding is a UserDict (not a dict), and with return_tensors=None
        # "length" is a plain list, so check membership and convert only tensors/arrays
        if "length" in tks:
            lengths = tks["length"]
            return lengths.tolist() if hasattr(lengths, "tolist") else list(lengths)
        elif getattr(tks, "encodings", None):
            # Fast tokenizers expose tokenizers.Encoding objects; slow ones return None here
            return [len(t.tokens) for t in tks.encodings]
        else:
            # Generic fallback: count the token ids of each sequence
            return [len(seq) for seq in tks["input_ids"]]

    crossencoder_module.CrossEncoderPatched.tokenize_lengths = patched_tokenize_lengths


patch_infinity_embedder()
patch_infinity_crossencoder()


================================================
FILE: gpt_server/model_worker/auto.py
================================================
import json
import traceback
from typing import List

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from loguru import logger
import torch
from vllm.tool_parsers import ToolParserManager

from gpt_server.model_handler.tool_parser import tool_parser
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import guess_tool_parser_by_model
from gpt_server.settings import get_model_config


class AutoWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="AutoModelForCausalLM",
        )

        self.stop_words_ids = []

        self.stop = [
            self.tokenizer.decode(skip_word) for skip_word in self.stop_words_ids
        ]
        tool_parser_name = guess_tool_parser_by_model(model_path)
        model_config = get_model_config()

        # from https://github.com/xorbitsai/inference/blob/c70ea74fa820a613f8d577047ef1818da20a96b3/xinference/model/llm/llm_family_modelscope.json
        self.tool_parser = ToolParserManager.get_tool_parser(tool_parser_name)(
            self.tokenizer
        )
        logger.warning(
            f"Started model: {model_names[0]} | tool parser: {tool_parser_name} | reasoning parser: {model_config.reasoning_parser}"
        )

    async def generate_stream_gate(self, params):
        self.call_ct += 1
        try:
            tools = params.get("tools", None)
            api_type = params.get("api_type", "chat")
            full_text = ""
            ret = {}
            if api_type == "chat":
                async for ret in self.backend.stream_chat(params=params):
                    full_text += ret.get("text", "")
                    yield json.dumps(ret).encode() + b"\0"
                # ------ add tool_calls ------
                tool_parser_result = tool_parser(
                    full_text=full_text,
                    tool_parser_=self.tool_parser,
                    tools=tools,
                    ret=ret,
                )
                if tool_parser_result:
                    yield tool_parser_result
                # ------ add tool_calls ------
            else:
                async for ret in self.backend.stream_chat(params=params):
                    yield ret.encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            traceback.print_exc()
            logger.info(e)
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"


if __name__ == "__main__":
    AutoWorker.run()


================================================
FILE: gpt_server/model_worker/base/__init__.py
================================================


================================================
FILE: gpt_server/model_worker/base/base_model_worker.py
================================================
import threading
import time
from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests

from fastchat.conversation import Conversation
from fastchat.utils import pretty_print_semaphore


def build_logger():
    from loguru import logger

    return logger


worker = None
logger = None
WORKER_HEART_BEAT_INTERVAL = 6
app = FastAPI()


def heart_beat_worker(obj: "BaseModelWorker"):
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        obj.send_heart_beat()


class BaseModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ):
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_names = model_names or [model_path.split("/")[-1]]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        self.conv.sep_style = int(self.conv.sep_style)
        self.multimodal = multimodal
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
        self.semaphore = None

        self.heart_beat_thread = None

        if logger is None:
            logger = build_logger()
        if worker is None:
            worker = self

    def make_conv_template(
        self,
        conv_template: str = None,
        model_path: str = None,
    ) -> Conversation:
        """
        Can be overridden to customize the conversation template for different model workers.
        """
        from fastchat.conversation import get_conv_template
        from fastchat.model.model_adapter import get_conversation_template

        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        return conv

    def init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker,
            args=(self,),
            daemon=True,
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_addr": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
            "multimodal": self.multimodal,
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url,
                    json={
                        "worker_addr": self.worker_addr,
                        "queue_length": self.get_queue_length(),
                    },
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, KeyError) as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if self.semaphore is None:
            return 0
        else:
            semaphore_value = (
                self.semaphore._value
                if self.semaphore._value is not None
                else self.limit_worker_concurrency
            )
            waiter_count = (
                0 if self.semaphore._waiters is None else len(self.semaphore._waiters)
            )
            return self.limit_worker_concurrency - semaphore_value + waiter_count

    def get_status(self):
        return {
            "model_names": self.model_names,
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def count_token(self, params):
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
            "error_code": 0,
        }
        return ret

    def get_conv_template(self):
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        raise NotImplementedError

    def generate_gate(self, params):
        raise NotImplementedError

    def get_embeddings(self, params):
        raise NotImplementedError

    def classify(self, params):
        raise NotImplementedError

    def transcription(self, params):
        raise NotImplementedError

    def generate_voice_stream(self, params):
        raise NotImplementedError

    def get_image_output(self, params):
        raise NotImplementedError


================================================
FILE: gpt_server/model_worker/base/model_worker_base.py
================================================
import asyncio
from datetime import datetime
from typing import List
import json
import sys
import shutil
from abc import ABC
from contextlib import asynccontextmanager
from fastapi import BackgroundTasks, Request, FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastchat.utils import SEQUENCE_LENGTH_KEYS
from loguru import logger
import os
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    AutoConfig,
    PreTrainedTokenizer,
)
import uuid
from gpt_server.utils import get_free_tcp_port, STATIC_DIR, local_ip
from gpt_server.model_worker.base.base_model_worker import BaseModelWorker
from gpt_server.model_handler.tool_parser import ToolCallStreamProcessor

worker = None
app = FastAPI()
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


def get_context_length_(config):
    """Get the context length of a model from a huggingface model config."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        try:
            rope_scaling_factor = config.rope_scaling["factor"]
        except KeyError:
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

    for key in SEQUENCE_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
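

# Sketch (illustrative, not used by the worker): get_context_length_ computes
#     context_length = int(rope_scaling["factor"] * <first length key present>)
# and falls back to 2048 when no length key exists. The key list below is an
# assumption that mirrors fastchat.utils.SEQUENCE_LENGTH_KEYS, and
# SimpleNamespace stands in for a Hugging Face config object.
# Example: max_position_embeddings=32768 with rope_scaling={"factor": 4.0}
# gives 4.0 * 32768 = 131072 tokens.
from types import SimpleNamespace

_SKETCH_LENGTH_KEYS = [
    "max_position_embeddings",
    "max_sequence_length",
    "seq_length",
    "max_seq_len",
    "model_max_length",
]


def _sketch_context_length(config) -> int:
    # A missing or empty rope_scaling means no context extension (factor 1)
    rope_scaling = getattr(config, "rope_scaling", None) or {}
    factor = rope_scaling.get("factor", 1)
    for key in _SKETCH_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(factor * val)
    return 2048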


async def cleanup_static_files():
    """Delete the static file directory and recreate it empty."""
    await asyncio.sleep(10)  # short 10 s delay before cleaning
    logger.debug(f"{datetime.now()}  Start cleaning static file directory: {STATIC_DIR}")
    shutil.rmtree(STATIC_DIR, ignore_errors=True)
    os.makedirs(STATIC_DIR, exist_ok=True)
    logger.debug(f"{datetime.now()}  Cleanup finished")
    await asyncio.sleep(10)  # short 10 s delay after cleaning


async def run_scheduler():
    """Run the static-file cleanup task every 12 hours."""
    while True:
        await cleanup_static_files()
        await asyncio.sleep(60 * 60 * 12)  # 12 hours = 43200 s
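

# Sketch (hypothetical helper, not used by the worker): run_scheduler repeats
# its job forever with a fixed 12-hour sleep. A variant that takes the job and
# interval as parameters (plus an optional run limit, handy for testing) is
# below. It does not swallow asyncio.CancelledError, so the owner can stop it
# with task.cancel(), for example from the lifespan shutdown:
#     task = asyncio.create_task(_sketch_run_periodically(cleanup, 12 * 3600))
#     ...
#     task.cancel()
import asyncio


async def _sketch_run_periodically(job, interval_s: float, max_runs=None) -> int:
    """Await `job()` every `interval_s` seconds; return the number of completed runs."""
    runs = 0
    while max_runs is None or runs < max_runs:
        await job()
        runs += 1
        # Skip the trailing sleep once the run limit is reached
        if max_runs is None or runs < max_runs:
            await asyncio.sleep(interval_s)
    return runs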


def pop_matching_tool(tools, tool_choice):
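    # Name of the function requested by tool_choice
    target_name = tool_choice["function"]["name"]

    # Search the tools list for a matching entry and remove it from the list
    for index, tool in enumerate(tools):
        if tool["function"]["name"] == target_name:
            return [tools.pop(index)]

    # Return None when no tool matches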
    return None
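

# Sketch (hypothetical, not used by the worker): pop_matching_tool removes the
# matched entry from the caller's `tools` list. A non-mutating lookup that takes
# the same OpenAI-style tool_choice dict, e.g.
#     {"type": "function", "function": {"name": "get_weather"}},
# and, like the original, returns None when nothing matches:
def _sketch_select_tool(tools, tool_choice):
    target_name = tool_choice["function"]["name"]
    for tool in tools:
        if tool["function"]["name"] == target_name:
            return [tool]  # the caller's list is left unchanged
    return None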


class ModelWorkerBase(BaseModelWorker, ABC):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
        model_type: str = "AutoModel",
        multimodal: bool = False,
    ):
        is_vision = False
        if model_type not in ["asr", "tts", "image"]:
            try:
                self.model_config = AutoConfig.from_pretrained(
                    model_path, trust_remote_code=True
                )
            except ValueError as e:
                logger.warning(e)
                self.model_config = {}
            self.max_position_embeddings = getattr(
                self.model_config, "max_position_embeddings", 512
            )
            # logger.info(f"Model config: {self.model_config}")
            self.vision_config = getattr(self.model_config, "vision_config", None)
            is_vision = self.vision_config is not None
            if is_vision:
                multimodal = True
                logger.warning(f"{model_names[0]} is a multimodal model")
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            multimodal=multimodal,
        )
        os.environ["WORKER_NAME"] = self.__class__.__name__
        self.worker_name = self.__class__.__name__
        self.model_type = model_type
        self.model_path = model_path
        self.model = None
        self.backend = None
        self.chat_template = None
        self.vl_chat_template = None
        self.tokenizer: PreTrainedTokenizer | None = None
        self.load_model_tokenizer(model_path)
        self.context_len = self.get_context_length()
        logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
        self.init_heart_beat()
        global worker
        if worker is None:
            worker = self
            logger.info("worker assigned")

    def preprocess_params(self, params: dict) -> dict:
        """Preprocess request params."""
        # ---------- attach chat_template ----------
        params["chat_template"] = self.chat_template
        # ---------- attach multimodal info ----------
        if hasattr(self, "vision_config") and self.vision_config:
            params["multimodal"] = True
            params["chat_template"] = self.vl_chat_template
        # ---------- wrap a plain-string prompt as a single user message ----------
        messages = params.get("messages", [])
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
            params["messages"] = messages
        # ---------- handle tools, honoring tool_choice ----------
        tool_choice = params.get("tool_choice", "none")
        tools = params.get("tools", None)
        params["extra_prompt"] = ""
        if tools:
            if tool_choice == "none":
                tools = None  # OK
            elif tool_choice == "auto":
                pass  # OK
            elif tool_choice == "required":
                params["extra_prompt"] = """<tool_call>\n{"name":"""
            elif isinstance(tool_choice, dict):
                tools = pop_matching_tool(tools=tools, tool_choice=tool_choice)
                tool_name = tool_choice["function"]["name"]
                params[
                    "extra_prompt"
                ] = f"""<tool_call>
{{"name": "{tool_name}", "arguments": 
    """
        params["tools"] = tools
        return params

    def get_context_length(
        self,
    ):
        """Maximum supported context length in tokens."""
        if self.model is None and self.backend is None:
            return 512
        return get_context_length_(self.model_config)

    def get_model_class(self):
        MODEL_CLASS = AutoModel
        if self.model_type == "LlamaForCausalLM":
            MODEL_CLASS = LlamaForCausalLM
            register = AutoModelForCausalLM._model_mapping.register
            register(LlamaForCausalLM.config_class, LlamaForCausalLM, exist_ok=True)
            MODEL_CLASS = AutoModelForCausalLM

        elif self.model_type == "AutoModel":
            MODEL_CLASS = AutoModel
        elif self.model_type == "AutoModelForCausalLM":
            MODEL_CLASS = AutoModelForCausalLM

        return MODEL_CLASS

    def load_model_tokenizer(self, model_path):
        """Load the model and tokenizer, assigning self.model and self.tokenizer directly."""
        if self.model_type in ["embedding", "asr", "tts", "image"]:
            return 1
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            encode_special_tokens=True,
        )
        # Default to "" so the substring checks below do not fail when unset
        backend = os.getenv("backend", "")
        if backend == "vllm":
            from gpt_server.model_backend.vllm_backend import VllmBackend

            logger.info(f"{self.worker_name} using vllm backend")
            self.backend = VllmBackend(
                model_path=self.model_path, tokenizer=self.tokenizer
            )
        elif "sglang" in backend:
            from gpt_server.model_backend.sglang_backend import SGLangBackend

            logger.info(f"{self.worker_name} using SGLang backend")
            self.backend = SGLangBackend(
                model_path=self.model_path, tokenizer=self.tokenizer
            )
        elif "lmdeploy" in backend:
            from gpt_server.model_backend.lmdeploy_backend import LMDeployBackend

            logger.info(f"{self.worker_name} using LMDeploy backend")
            self.backend = LMDeployBackend(
                model_path=self.model_path, tokenizer=self.tokenizer
            )

        elif backend == "hf":
            from gpt_server.model_backend.hf_backend import HFBackend

            logger.info(f"{self.worker_name} using hf backend")
            MODEL_CLASS = self.get_model_class()
            self.model = MODEL_CLASS.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype="auto",
                device_map="auto",
            )

            self.model = self.model.eval()
            # Wrap the loaded model in the HF backend
            self.backend = HFBackend(tokenizer=self.tokenizer, model=self.model)
        logger.info("load_model_tokenizer finished")

    async def generate_gate(self, params):
        """Collect the streamed chunks into a single non-streaming response."""
        full_text = ""
        full_reasoning_content = ""
        processor = ToolCallStreamProcessor()
        ret = {}
        async for chunk in self.generate_stream_gate(params):
            # Each chunk is a JSON object followed by a trailing b"\0" delimiter;
            # decode it once and reuse the result.
            ret = json.loads(chunk[:-1].decode())
            full_text += ret.get("text", "")
            reasoning_content = ret.get("reasoning_content", "")
            if reasoning_content:
                full_reasoning_content += reasoning_content
            tool_calls = ret.get("tool_calls", None)
            if tool_calls:
                processor.process_chunk(tool_calls)
        # Start from the last chunk (keeping its other fields) and overwrite
        # the accumulated ones.
        ret["text"] = full_text
        ret["tool_calls"] = processor.get_completed_tool_calls()
        ret["reasoning_content"] = full_reasoning_content
        return ret

    @classmethod
    def get_worker(
        cls,
        model_path: str,
        worker_addr: str,
        controller_addr: str = "http://localhost:21001",
        worker_id: str = str(uuid.uuid4())[:8],
        model_names: List[str] = [""],
        limit_worker_concurrency: int = 1024,
        conv_template: str = None,  # type: ignore
    ):
        worker = cls(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )
        return worker

    @classmethod
    def run(cls):
        import uvicorn
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("--num_gpus", type=int, default=1)
        parser.add_argument("--backend", type=str, default="hf")

        parser.add_argument(
            "--model_name_or_path", type=str, default="model_name_or_path"
        )
        parser.add_argument(
            "--model_names", type=lambda s: s.split(","), default="model_names"
        )
        parser.add_argument("--lora", type=str, default=None)
        parser.add_argument("--host", type=str, default="localhost")
        parser.add_argument(
            "--controller_address", type=str, default="http://localhost:21001"
        )
        parser.add_argument("--enable_prefix_caching", type=str, default="False")
        parser.add_argument("--enable_chunked_prefill", type=str, default="False")
        parser.add_argument("--dtype", type=str, default="auto")
        parser.add_argument("--max_model_len", type=str, default=None)
        parser.add_argument("--gpu_memory_utilization", type=str, default="0.8")
        # kv_cache_quant_policy
        parser.add_argument("--kv_cache_quant_policy", type=str, default="0")
        # vad_model
        parser.add_argument("--vad_model", type=str, default="")
        # punc_model
        parser.add_argument("--punc_model", type=str, default="")
        # log_level
        parser.add_argument("--log_level", type=str, default="WARNING")
        # task_type
        parser.add_argument("--task_type", type=str, default="auto")
        # limit_worker_concurrency
        parser.add_argument("--limit_worker_concurrency", type=int, default=1024)
        # port
        parser.add_argument("--port", type=int, default=None)
        # model_type
        parser.add_argument("--model_type", type=str, default="auto")
        # hf_overrides
        parser.add_argument("--hf_overrides", type=str, default="")
        # reasoning_parser
        parser.add_argument("--reasoning_parser", type=str, default="")
        parser.add_argument("--speculative_algorithm", type=str, default="")
        parser.add_argument("--speculative_num_steps", type=str, default="")
        # tool_call_parser
        parser.add_argument("--tool_call_parser", type=str, default="")
        # enforce_eager
        parser.add_argument("--enforce_eager", type=str, default="False")

        args = parser.parse_args()
        os.environ["num_gpus"] = str(args.num_gpus)
        if args.backend == "vllm":
            os.environ["backend"] = "vllm"
        elif args.backend == "hf":
            os.environ["backend"] = "hf"
        elif args.backend == "lmdeploy-pytorch":
            os.environ["backend"] = "lmdeploy-pytorch"
        elif args.backend == "lmdeploy-turbomind":
            os.environ["backend"] = "lmdeploy-turbomind"
        elif args.backend == "sglang":
            os.environ["backend"] = "sglang"

        if args.lora:
            os.environ["lora"] = args.lora
        if args.max_model_len:
            os.environ["max_model_len"] = args.max_model_len
        if args.vad_model:
            os.environ["vad_model"] = args.vad_model
        if args.punc_model:
            os.environ["punc_model"] = args.punc_model
        if args.hf_overrides:
            os.environ["hf_overrides"] = args.hf_overrides
        if args.reasoning_parser:
            os.environ["reasoning_parser"] = args.reasoning_parser
        if args.speculative_algorithm:
            os.environ["speculative_algorithm"] = args.speculative_algorithm
        if args.speculative_num_steps:
            os.environ["speculative_num_steps"] = args.speculative_num_steps
        if args.tool_call_parser:
            os.environ["tool_call_parser"] = args.tool_call_parser

        os.environ["model_type"] = args.model_type
        os.environ["enable_prefix_caching"] = args.enable_prefix_caching
        os.environ["enable_chunked_prefill"] = args.enable_chunked_prefill
        os.environ["gpu_memory_utilization"] = args.gpu_memory_utilization
        os.environ["kv_cache_quant_policy"] = args.kv_cache_quant_policy
        os.environ["dtype"] = args.dtype
        os.environ["log_level"] = args.log_level
        os.environ["task_type"] = args.task_type
        os.environ["enforce_eager"] = args.enforce_eager
        limit_worker_concurrency = int(args.limit_worker_concurrency)
        logger.remove(0)
        log_level = os.getenv("log_level", "WARNING")
        logger.add(sys.stderr, level=log_level, enqueue=True)

        host = args.host
        controller_address = args.controller_address
        if args.port:
            port = args.port
        else:
            port = get_free_tcp_port()
        os.environ["WORKER_PORT"] = str(port)
        os.environ["WORKER_HOST"] = str(local_ip)
        worker_addr = f"http://{host}:{port}"
        model_names = args.model_names
        logger.info(f"{model_names[0]} args: \n{args}")

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Startup
            global worker
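            # Keep a reference so the cleanup task is not garbage-collected mid-run
            scheduler_task = asyncio.create_task(run_scheduler())
            worker = cls.get_worker(
                worker_addr=worker_addr,
                model_path=args.model_name_or_path,
                model_names=model_names,
                conv_template="chatglm3",
                controller_addr=controller_address,
                limit_worker_concurrency=limit_worker_concurrency,
            )
            yield
            # Shutdown: stop the periodic cleanup and release the backend gracefully
            scheduler_task.cancel()
            if worker.backend is not None:
                worker.backend.shutdown()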

        app.router.lifespan_context = lifespan

        uvicorn.run(app, host=host, port=port)


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)

    return background_tasks
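

# Sketch (hypothetical, not used by the routes below): the routes acquire
# worker.semaphore before calling the model and release it afterwards. Using
# `async with semaphore` releases the slot even when the model call raises,
# which a bare acquire()/release() pair does not. `limit` plays the role of
# limit_worker_concurrency; `peak` records the highest number of jobs that
# were running at once.
import asyncio


async def _sketch_run_limited(jobs, limit: int):
    """Run coroutine factories with at most `limit` in flight; return (results, peak)."""
    semaphore = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def guarded(job):
        nonlocal in_flight, peak
        async with semaphore:  # released on normal exit and on exceptions
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await job()
            finally:
                in_flight -= 1

    results = await asyncio.gather(*(guarded(j) for j in jobs), return_exceptions=True)
    return results, peak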


request_id = 0


def gen_request_id():
    global request_id
    request_id += 1
    return str(request_id)


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = gen_request_id()
    params["request_id"] = request_id
    params["request"] = request
    logger.debug(f"params {params}")
    # 对 params 进行预处理
    params = worker.preprocess_params(params)
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks(request_id)
    return StreamingResponse(generator, background=background_tasks)
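

# Sketch (hypothetical client-side helper, not part of this module): each item
# yielded by generate_stream_gate is a JSON object followed by a b"\0"
# delimiter (json.dumps(ret).encode() + b"\0" in the workers). Network reads do
# not respect those boundaries, so a consumer should buffer bytes and split on
# the delimiter instead of assuming one JSON object per read. JSON encoded as
# UTF-8 never contains a 0x00 byte, so the delimiter is unambiguous.
import json


def _sketch_iter_null_delimited_json(byte_chunks):
    """Yield decoded dicts from an iterable of arbitrarily split byte chunks."""
    buffer = b""
    for chunk in byte_chunks:
        buffer += chunk
        # Everything before the last b"\0" is complete; the tail stays buffered
        *complete, buffer = buffer.split(b"\0")
        for raw in complete:
            if raw:
                yield json.loads(raw.decode())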


@app.post("/worker_generate_voice_stream")
async def api_generate_voice_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = gen_request_id()
    params["request_id"] = request_id
    params["request"] = request
    logger.debug(f"params {params}")
    generator = worker.generate_voice_stream(params)
    background_tasks = create_background_tasks(request_id)
    response_format = params["response_format"]
    content_type = {
        "mp3": "audio/mpeg",
        "opus": "audio/opus",
        "aac": "audio/aac",
        "flac": "audio/flac",
        "wav": "audio/wav",
        "pcm": "audio/pcm",
    }.get(response_format, f"audio/{response_format}")
    return StreamingResponse(
        generator,
        background=background_tasks,
        media_type=content_type,
        headers={
            "Content-Disposition": f"attachment; filename=speech.{response_format}",
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
            "Transfer-Encoding": "chunked",
        },
    )


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    request_id = gen_request_id()
    params["request_id"] = request_id
    params["request"] = request
    params.pop("prompt", None)
    logger.debug(f"params {params}")
    # 对 params 进行预处理
    params = worker.preprocess_params(params)
    try:
        output = await worker.generate_gate(params)
    finally:
        release_worker_semaphore()

    return JSONResponse(output)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    return {"context_length": worker.context_len}


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    logger.debug(f"params {params}")
    try:
        embedding = await worker.get_embeddings(params)
    finally:
        release_worker_semaphore()
    return JSONResponse(content=embedding)


@app.post("/worker_get_image_output")
async def api_get_image_output(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    logger.debug(f"params {params}")
    try:
        result = await worker.get_image_output(params)
    finally:
        release_worker_semaphore()
    return JSONResponse(content=result)


@app.post("/worker_get_classify")
async def api_get_classify(request: Request):
    params = await request.json()
    logger.debug(f"params {params}")
    await acquire_worker_semaphore()
    try:
        outputs = await worker.classify(params)
    finally:
        release_worker_semaphore()
    return JSONResponse(content=outputs)


@app.post("/worker_get_transcription")
async def api_get_transcription(request: Request):
    params = await request.json()
    logger.debug(f"params {params}")
    await acquire_worker_semaphore()
    try:
        outputs = await worker.transcription(params)
    finally:
        release_worker_semaphore()
    return JSONResponse(content=outputs)


================================================
FILE: gpt_server/model_worker/embedding_infinity.py
================================================
import os
from typing import List
import asyncio
from loguru import logger

from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine
from infinity_emb.inference.select_model import get_engine_type_from_config
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import get_embedding_mode, is_base64_image

label_to_category = {
    "S": "sexual",
    "H": "hate",
    "HR": "harassment",
    "SH": "self-harm",
    "S3": "sexual/minors",
    "H2": "hate/threatening",
    "V2": "violence/graphic",
    "V": "violence",
    "OK": "OK",
}
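

# Sketch (hypothetical, not used by EmbeddingWorker.classify): turning the
# per-label scores for one input into an OpenAI-style moderation result. Raw
# labels are mapped through a table like label_to_category above, a category is
# flagged when its score exceeds `threshold`, and the input is flagged when any
# harmful category is flagged. Treating "OK" as the only non-harmful label is
# an assumption about the classifier's label set.
def _sketch_moderation_result(entries, threshold, label_map):
    categories, scores = {}, {}
    for entry in entries:
        label = label_map.get(entry["label"], entry["label"])
        scores[label] = entry["score"]
        categories[label] = entry["score"] > threshold
    flagged = any(flag for label, flag in categories.items() if label != "OK")
    return {"flagged": flagged, "categories": categories, "category_scores": scores}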


class EmbeddingWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="embedding",
        )
        if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
            device = "cpu"
        else:
            device = "cuda"
        logger.warning(f"Loading on {device}...")
        model_type = getattr(self.model_config, "model_type", None)
        bettertransformer = False
        # TODO: bettertransformer = True causes problems with transformers
        # if model_type is not None and "deberta" in model_type:
        #     bettertransformer = False
        engine_args = EngineArgs(
            model_name_or_path=model_path,
            engine="torch",
            embedding_dtype="float32",
            dtype="float32",
            device=device,
            bettertransformer=bettertransformer,
        )
        self.mode = get_embedding_mode(model_path=model_path)
        self.engine: AsyncEmbeddingEngine = AsyncEngineArray.from_args([engine_args])[0]
        loop = asyncio.get_running_loop()
        loop.create_task(self.engine.astart())
        logger.warning(f"Model: {model_names[0]}")
        logger.warning(f"Using {self.mode} model...")

    async def astart(self):
        await self.engine.astart()

    async def get_embeddings(self, params):
        self.call_ct += 1
        ret = {"embedding": [], "token_num": 0}
        texts: list = params["input"]
        embedding = []
        usage = 0
        if self.mode == "embedding":
            texts = list(map(lambda x: x.replace("\n", " "), texts))
            embeddings, usage = await self.engine.embed(sentences=texts)
            embedding = [embedding.tolist() for embedding in embeddings]
        elif self.mode == "rerank":
            query = params.get("query", None)
            ranking, usage = await self.engine.rerank(
                query=query, docs=texts, raw_scores=False
            )
            ranking = [
                {
                    "index": i.index,
                    "relevance_score": i.relevance_score,
                    "document": i.document,
                }
                for i in ranking
            ]
            ranking.sort(key=lambda x: x["index"])
            embedding = [
                [round(float(score["relevance_score"]), 6)] for score in ranking
            ]
        elif self.mode == "image":
            if (
                isinstance(texts[0], bytes)
                or "http" in texts[0]
                or is_base64_image(texts[0])
            ):
                embeddings, usage = await self.engine.image_embed(images=texts)
            else:
                embeddings, usage = await self.engine.embed(sentences=texts)

            embedding = [embedding.tolist() for embedding in embeddings]
        ret["embedding"] = embedding
        ret["token_num"] = usage
        return ret

    async def classify(self, params):
        logger.info(f"params {params}")
        logger.info(f"worker_id: {self.worker_id}")
        self.call_ct += 1
        ret = {}
        texts = params["input"]
        threshold = params["threshold"]
        scores, usage = await self.engine.classify(sentences=texts, raw_scores=False)
        results = []
        for item in scores:
            categories_flags = {}
            category_scores = {}
            for entry in item:
                label = entry["label"]  # raw label from the classifier
                # Map the raw label to a category name; fall back to the raw
                # label when no mapping exists.
                label = label_to_category.get(label, label)
                score = entry["score"]
                # Record the category score
                category_scores[label] = score
                # Flag the category when its score exceeds the threshold
                categories_flags[label] = bool(score > threshold)
            # Flag the input when any harmful category (anything except "OK") is flagged
            flagged = any(
                flag for label, flag in categories_flags.items() if label != "OK"
            )
            results.append(
                {
                    "flagged": flagged,
                    "categories": categories_flags,
                    "category_scores": category_scores,
                }
            )
        ret["results"] = results
        ret["token_num"] = usage
        return ret


if __name__ == "__main__":
    EmbeddingWorker.run()


================================================
FILE: gpt_server/model_worker/embedding_sentence_transformers.py
================================================
import os
from typing import List

from loguru import logger
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import (
    PoolingModel,
)


class EmbeddingWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="embedding",
        )
        if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
            device = "cpu"
        else:
            device = "cuda"
        logger.warning(f"Loading on {device}...")
        self.pool_model = PoolingModel(model_path=model_path)
        logger.warning(f"Model: {model_names[0]}")

    async def get_embeddings(self, params):
        self.call_ct += 1
        texts = params["input"]
        query = params.get("query", None)
        ret = self.pool_model.pooling(query=query, documents=texts)
        return ret


if __name__ == "__main__":
    EmbeddingWorker.run()


================================================
FILE: gpt_server/model_worker/embedding_v2.py
================================================
import os
from typing import List
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
import sentence_transformers
import asyncio
from asyncio import Queue
from loguru import logger


class EmbeddingWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="embedding",
        )
        if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
            device = "cpu"
        else:
            device = "cuda"
        logger.info(f"Loading on {device}...")
        model_kwargs = {"device": device}
        self.request_queue: Queue = Queue()
        self.loop = asyncio.get_running_loop()

        self.worker_tasks = [
            self.loop.create_task(self.batch_processor()) for _ in range(1)
        ]
        # -------------------------------------------------------------------------
        self.batch_size = 64
        self.encode_kwargs = {
            "normalize_embeddings": True,
            "batch_size": self.batch_size,
        }
        self.mode = "embedding"
        # Switch to rerank mode if any served model name contains "rerank"
        for model_name in model_names:
            if "rerank" in model_name:
                self.mode = "rerank"
                break
        if self.mode == "rerank":
            self.client = sentence_transformers.CrossEncoder(
                model_name=model_path, **model_kwargs
            )
            logger.warning("Using a rerank model...")
        elif self.mode == "embedding":
            self.client = sentence_transformers.SentenceTransformer(
                model_path, **model_kwargs
            )
            logger.warning("Using an embedding model...")
        self.warm_up()

    def warm_up(self):
        logger.info("Starting warm-up")
        if self.mode == "embedding":
            self.client.encode(sentences=["你是谁"] * 10)
        elif self.mode == "rerank":
            self.client.predict(sentences=[["你好", "你好啊"]] * 10)

    async def batch_processor(self):
        logger.warning("Entering batch_processor")
        while True:
            requests = []
            batch_size = 0
            try:
                while batch_size < self.batch_size:
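                    # Wait at most 100 ms for the next request
                    request = await asyncio.wait_for(
                        self.request_queue.get(), timeout=0.1
                    )
                    requests.append(request)
                    # batch_size counts texts, not requests
                    batch_size += len(request[0]["input"])

            except asyncio.TimeoutError:
                # Queue stayed empty for 100 ms: process whatever has been collected
                pass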
            if requests:
                try:
                    all_input = [request[0]["input"] for request in requests]
                    futures = [request[1] for request in requests]

                    if self.mode == "embedding":
                        # Dynamic batching: flatten the texts of all queued requests
                        # all_input: list[list[str]], one list per request
                        # request[0] is the request's params dict
                        all_texts = [text for input in all_input for text in input]
                        logger.debug(all_texts)
                        embeddings = self.client.encode(
                            all_texts, **self.encode_kwargs
                        ).tolist()

                    elif self.mode == "rerank":
                        # all_input = [ List[str] ]
                        # all_query = [str]
                        # all_texts = [str]
                        # request[0] ---> params
                        all_query = [request[0]["query"] for request in requests]
                        all_sentence_pairs = []

                        for query, inps in zip(all_query, all_input):
                            sentence_pairs = [[query, inp] for inp in inps]

                            all_sentence_pairs.extend(sentence_pairs)
                        logger.debug(all_sentence_pairs)
                        scores = self.client.predict(all_sentence_pairs)
                        embeddings = [[float(score)] for score in scores]

                    idx = 0
                    for future, request in zip(futures, requests):
                        num_texts = len(request[0]["input"])
                        future.set_result(embeddings[idx : idx + num_texts])
                        idx += num_texts
                except Exception as e:
                    logger.exception(e)
                    for future in futures:
                        # A future may already hold a result if slicing failed midway
                        if not future.done():
                            future.set_exception(e)

    async def add_request(self, params: dict, future: asyncio.Future):

        await self.request_queue.put(item=(params, future))

    async def aembed(self, params: dict, future: asyncio.Future):
        await self.add_request(params, future)

    async def rerank(self, params: dict, future: asyncio.Future):
        await self.add_request(params, future)

    async def get_embeddings(self, params):
        self.call_ct += 1
        ret = {"embedding": [], "token_num": 0}
        texts = params["input"]
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if self.mode == "embedding":
            token_num = 0
            await self.aembed(params, future)
            embedding = await future
        elif self.mode == "rerank":
            token_num = 0
            await self.rerank(params, future)
            embedding = await future
        ret["embedding"] = embedding
        ret["token_num"] = token_num
        return ret


if __name__ == "__main__":
    EmbeddingWorker.run()
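

# Illustrative sketch (hypothetical helper, not called by EmbeddingWorker): the
# dynamic batching in batch_processor flattens the "input" lists of several
# queued requests into one list, runs a single encode/predict call, and then
# slices the flat results back so each request's future receives exactly as many
# rows as it sent texts. This standalone version adds a length check, which the
# worker itself does not perform.
from typing import List, Sequence, TypeVar

_T = TypeVar("_T")


def split_batched_results(
    per_request_inputs: Sequence[Sequence[str]], flat_results: Sequence[_T]
) -> List[List[_T]]:
    """Slice flat batch results back into one list per request, in queue order."""
    total = sum(len(inputs) for inputs in per_request_inputs)
    if total != len(flat_results):
        raise ValueError(f"expected {total} results, got {len(flat_results)}")
    out: List[List[_T]] = []
    idx = 0
    for inputs in per_request_inputs:
        # Each request owns a contiguous slice of the flat results
        out.append(list(flat_results[idx : idx + len(inputs)]))
        idx += len(inputs)
    return out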


================================================
FILE: gpt_server/model_worker/embedding_vllm.py
================================================
import os
from typing import List
from loguru import logger

from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import get_embedding_mode
import numpy as np
from vllm import LLM, EmbeddingRequestOutput, ScoringRequestOutput
from gpt_server.settings import get_model_config

label_to_category = {
    "S": "sexual",
    "H": "hate",
    "HR": "harassment",
    "SH": "self-harm",
    "S3": "sexual/minors",
    "H2": "hate/threatening",
    "V2": "violence/graphic",
    "V": "violence",
    "OK": "OK",
}


def template_format(queries: List[str], documents: List[str]):
    model_config = get_model_config()
    hf_overrides = model_config.hf_overrides
    if hf_overrides:
        architectures = hf_overrides.get("architectures") or []
        if architectures and architectures[0] == "Qwen3ForSequenceClassification":
            logger.info("Applying the Qwen3ForSequenceClassification template...")
            prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
            instruction = "Given a web search query, retrieve relevant passages that answer the query"

            query_template = f"{prefix}<Instruct>: {instruction}\n<Query>: {{query}}\n"
            document_template = f"<Document>: {{doc}}{suffix}"
            queries = [query_template.format(query=query) for query in queries]
            documents = [document_template.format(doc=doc) for doc in documents]
            return queries, documents
    return queries, documents


class EmbeddingWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="embedding",
        )
        model_config = get_model_config()
        hf_overrides = model_config.hf_overrides
        self.mode = get_embedding_mode(model_path=model_path)
        runner = "auto"
        if self.mode == "rerank":
            runner = "pooling"
        self.engine = LLM(
            model=model_path,
            tensor_parallel_size=model_config.num_gpus,
            max_model_len=model_config.max_model_len,
            gpu_memory_utilization=model_config.gpu_memory_utilization,
            enable_prefix_caching=model_config.enable_prefix_caching,
            runner=runner,
            hf_overrides=hf_overrides,
        )

        logger.warning(f"Model: {model_names[0]}")
        logger.warning(f"Using {self.mode} model...")

    async def get_embeddings(self, params):
        self.call_ct += 1
        ret = {"embedding": [], "token_num": 0}
        texts: list = params["input"]
        embedding = []
        if self.mode == "embedding":
            texts = list(map(lambda x: x.replace("\n", " "), texts))
            # ----------
            outputs: list[EmbeddingRequestOutput] = self.engine.embed(
                texts,
                truncate_prompt_tokens=self.max_position_embeddings - 4,
            )
            embedding = [o.outputs.embedding for o in outputs]
            embeddings_np = np.array(embedding)
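            # L2-normalize each row (axis=1), i.e. each text's embedding;
            # zero-norm rows are left as zeros instead of becoming NaN
            norm = np.linalg.norm(embeddings_np, ord=2, axis=1, keepdims=True)
            norm[norm == 0] = 1.0
            normalized_embeddings_np = embeddings_np / norm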
            embedding = normalized_embeddings_np.tolist()
        elif self.mode == "rerank":
            query = params.get("query", None)
            data_1 = [query] * len(texts)
            data_2 = texts
            data_1, data_2 = template_format(queries=data_1, documents=data_2)
            scores: list[ScoringRequestOutput] = self.engine.score(data_1, data_2)
            embedding = [[score.outputs.score] for score in scores]

        ret["embedding"] = embedding
        return ret


if __name__ == "__main__":
    EmbeddingWorker.run()
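

# Illustrative sketch (hypothetical helper, not called by EmbeddingWorker): the
# L2 normalization in get_embeddings divides every embedding row by its
# Euclidean norm, so dot products between normalized rows equal cosine
# similarities. This stdlib-only version leaves all-zero rows unchanged rather
# than dividing by zero.
import math
from typing import List, Sequence


def l2_normalize_rows(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Return a copy of ``rows`` in which every non-zero row has unit L2 norm."""
    normalized: List[List[float]] = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row))
        # A zero vector has no direction to preserve, so copy it as-is
        normalized.append([x / norm for x in row] if norm > 0 else list(row))
    return normalized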


================================================
FILE: gpt_server/model_worker/flux.py
================================================
import asyncio

import io
import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import pil_to_base64
import torch
from diffusers import FluxPipeline
from gpt_server.utils import STATIC_DIR

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class FluxWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        backend = os.environ["backend"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = FluxPipeline.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        ).to(self.device)

        logger.warning(f"Model: {model_names[0]}")

    async def get_image_output(self, params):
        self.call_ct += 1
        prompt = params["prompt"]
        response_format = params.get("response_format", "b64_json")
        # Run the blocking pipeline in a worker thread so the event loop stays free
        output = await asyncio.to_thread(
            self.pipe,
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            generator=torch.Generator(self.device).manual_seed(0),
        )
        image = output.images[0]
        result = {}
        if response_format == "b64_json":
            # Convert PIL image to base64
            base64 = pil_to_base64(pil_img=image)
            result = {
                "created": shortuuid.random(),
                "data": [{"b64_json": base64}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
            return result
        elif response_format == "url":
            # Generate a unique file name (to avoid collisions)
            file_name = str(uuid.uuid4()) + ".png"
            save_path = STATIC_DIR / file_name
            image.save(save_path, format="PNG")
            WORKER_PORT = os.environ["WORKER_PORT"]
            WORKER_HOST = os.environ["WORKER_HOST"]
            url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
            result = {
                "created": shortuuid.random(),
                "data": [{"url": url}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
        return result


if __name__ == "__main__":
    FluxWorker.run()
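

# Illustrative sketch (hypothetical helper, not called by FluxWorker): both
# branches of get_image_output return the same OpenAI-style envelope and differ
# only in the single "data" item ({"b64_json": ...} or {"url": ...}). Building
# the envelope in one place keeps the all-zero "usage" block consistent. The
# caller supplies "created" (FluxWorker passes shortuuid.random()).
from typing import Any, Dict


def build_image_response(data_item: Dict[str, str], created: Any) -> Dict[str, Any]:
    """Wrap one generated-image item in the image response envelope."""
    return {
        "created": created,
        "data": [data_item],
        # These image workers do not count tokens, so every usage field is zero
        "usage": {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        },
    }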


================================================
FILE: gpt_server/model_worker/funasr.py
================================================
import os
from typing import List
import base64
from loguru import logger
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from io import BytesIO


class FunASRWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="asr",
        )
        if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
            device = "cpu"
        else:
            device = "cuda"
        logger.warning(f"Loading on {device}...")
        vad_model = os.environ.get("vad_model", None)
        punc_model = os.environ.get("punc_model", None)
        self.model = AutoModel(
            model=model_path,
            vad_model=vad_model,
            punc_model=punc_model,
            vad_kwargs={"max_single_segment_time": 30000},
            device=device,
        )
        logger.warning(f"Model: {model_names[0]}")

    async def transcription(self, params):
        file_input = base64.b64decode(params["file"])  # Base64 → bytes
        file_input = BytesIO(file_input)
        ret = {}
        res = self.model.generate(
            input=file_input,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
        text = rich_transcription_postprocess(res[0]["text"])
        ret["text"] = text
        return ret


if __name__ == "__main__":
    FunASRWorker.run()


================================================
FILE: gpt_server/model_worker/qwen_image.py
================================================
import asyncio
import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import pil_to_base64
import torch
from diffusers import DiffusionPipeline
from gpt_server.utils import STATIC_DIR

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}

aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}

width, height = aspect_ratios["16:9"]
import re


def contains_chinese(text):
    pattern = re.compile(r"[\u4e00-\u9fff]")
    return bool(pattern.search(text))


class QwenImageWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        backend = os.environ["backend"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = DiffusionPipeline.from_pretrained(
            model_path, torch_dtype=torch.bfloat16
        ).to(self.device)

        logger.warning(f"Model: {model_names[0]}")

    async def get_image_output(self, params):
        self.call_ct += 1
        prompt = params["prompt"]
        response_format = params.get("response_format", "b64_json")
        inputs = {
            "prompt": prompt,
            "negative_prompt": " ",
            "num_inference_steps": 50,
            "true_cfg_scale": 4.0,
            "generator": torch.Generator(self.device).manual_seed(0),
        }
        size = params.get("size", None)
        if size:
            size_split = size.split("x")
            width, height = int(size_split[0]), int(size_split[1])
            inputs.update({"width": width, "height": height})
        output = await asyncio.to_thread(self.pipe, **inputs)
        image = output.images[0]
        result = {}
        if response_format == "b64_json":
            # Convert PIL image to base64
            base64 = pil_to_base64(pil_img=image)
            result = {
                "created": shortuuid.random(),
                "data": [{"b64_json": base64}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
            return result
        elif response_format == "url":
            # Generate a unique file name (to avoid collisions)
            file_name = str(uuid.uuid4()) + ".png"
            save_path = STATIC_DIR / file_name
            image.save(save_path, format="PNG")
            WORKER_PORT = os.environ["WORKER_PORT"]
            WORKER_HOST = os.environ["WORKER_HOST"]
            url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
            result = {
                "created": shortuuid.random(),
                "data": [{"url": url}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
        return result


if __name__ == "__main__":
    QwenImageWorker.run()
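

# Illustrative sketch (hypothetical helper, not called by QwenImageWorker):
# get_image_output turns an OpenAI-style "size" string such as "1328x1328" into
# width/height pipeline arguments. This variant also accepts a preset key such
# as "16:9" (for example from the module's aspect_ratios table) and raises a
# ValueError with a clear message on malformed input instead of a bare
# IndexError or int() error.
from typing import Dict, Tuple


def parse_image_size(size: str, presets: Dict[str, Tuple[int, int]]) -> Tuple[int, int]:
    """Return (width, height) for "WIDTHxHEIGHT" or a preset aspect-ratio key."""
    if size in presets:
        return presets[size]
    parts = size.lower().split("x")
    if len(parts) != 2 or not all(p.strip().isdigit() for p in parts):
        raise ValueError(f"invalid size {size!r}, expected 'WIDTHxHEIGHT'")
    width, height = (int(p) for p in parts)
    if width <= 0 or height <= 0:
        raise ValueError(f"size must be positive, got {size!r}")
    return width, height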


================================================
FILE: gpt_server/model_worker/qwen_image_edit.py
================================================
import asyncio

import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import (
    pil_to_base64,
    load_base64_or_url,
    bytesio2image,
)
from gpt_server.utils import STATIC_DIR
import torch
from diffusers import QwenImageEditPlusPipeline

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class QwenImageEditWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = QwenImageEditPlusPipeline.from_pretrained(model_path)
        self.pipe.to(torch.bfloat16)
        self.pipe.to(self.device)
        self.pipe.set_progress_bar_config(disable=None)
        logger.warning(f"Model: {model_names[0]}")

    async def get_image_output(self, params):
        prompt = params["prompt"]
        response_format = params.get("response_format", "b64_json")
        image: list = params["image"]
        image = [bytesio2image(await load_base64_or_url(img)) for img in image]
        # bytes_io = await load_base64_or_url(params["image"])
        # image = bytesio2image(bytes_io)
        inputs = {
            "image": image,
            "prompt": prompt,
            "negative_prompt": " ",
            "generator": torch.manual_seed(0),
            "true_cfg_scale": 4.0,
            "num_inference_steps": 40,
        }

        def _run_pipe():
            # Grad/inference mode is thread-local, so enter it in the worker thread
            with torch.inference_mode():
                return self.pipe(**inputs)

        output = await asyncio.to_thread(_run_pipe)
        image = output.images[0]

        result = {}
        if response_format == "b64_json":
            # Convert PIL image to base64
            base64 = pil_to_base64(pil_img=image)
            result = {
                "created": shortuuid.random(),
                "data": [{"b64_json": base64}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
            return result
        elif response_format == "url":
            # Generate a unique file name (to avoid collisions)
            file_name = str(uuid.uuid4()) + ".png"
            save_path = STATIC_DIR / file_name
            image.save(save_path, format="PNG")
            WORKER_PORT = os.environ["WORKER_PORT"]
            WORKER_HOST = os.environ["WORKER_HOST"]
            url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
            result = {
                "created": shortuuid.random(),
                "data": [{"url": url}],
                "usage": {
                    "total_tokens": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
                },
            }
        return result


if __name__ == "__main__":
    QwenImageEditWorker.run()


================================================
FILE: gpt_server/model_worker/spark_tts.py
================================================
import asyncio

import os
from typing import List
from loguru import logger
from gpt_server.model_handler.pitch import pitch_flashtts

pitch_flashtts()
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import load_base64_or_url
from flashtts.engine import AutoEngine
from flashtts.server.utils.audio_writer import StreamingAudioWriter

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# os.environ["VLLM_USE_V1"] = "0"


class SparkTTSWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="tts",
        )
        backend = os.environ["backend"]
        gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.6))
        self.engine = AutoEngine(
            model_path=model_path,
            max_length=32768,
            llm_device="auto",
            tokenizer_device="auto",
            detokenizer_device="auto",
            backend=backend,
            wav2vec_attn_implementation="sdpa",  # PyTorch SDPA attention for wav2vec
            llm_gpu_memory_utilization=gpu_memory_utilization,
            seed=0,
        )
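        loop = asyncio.get_running_loop()
        # ------------- Register a preset speaker -------------
        # "新闻联播女声" means "Xinwen Lianbo (CCTV News) female voice".
        # Keep a reference to the task so it is not garbage-collected mid-run.
        self._add_speaker_task = loop.create_task(
            self.engine.add_speaker(
                "新闻联播女声",
                audio=os.path.join(
                    root_dir, "assets/audio_data/roles/新闻联播女声/女声.wav"
                ),
            )
        )
        logger.warning(f"Model: {model_names[0]}")
        # The task above has not started yet here, so this list may omit the speaker
        logger.info(f"list_speakers: {self.engine.list_speakers()}")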

    # Main method of this model worker
    async def generate_voice_stream(self, params):
        if self.engine.engine_name != "spark":
            raise ValueError("Only Spark-TTS supports `generate_voice_stream`.")
        async for chunk_data in self.stream_async(params=params):
            yield chunk_data

    async def stream_async(self, params):
        text = params["text"]
        voice = params.get("voice", "新闻联播女声")
        response_format = params["response_format"]
        speed = params["speed"]
        pitch = params["pitch"]
        audio_writer = StreamingAudioWriter(
            format=response_format, sample_rate=self.engine.SAMPLE_RATE
        )
        generator = None
        if voice in self.engine.list_speakers():
            generator = self.engine.speak_stream_async(
                name=voice,
                text=text,
                length_threshold=50,
                window_size=50,
                speed=speed,
                pitch=pitch,
            )
        else:  # clone
            reference_audio = await load_base64_or_url(voice)
            generator = self.engine.clone_voice_stream_async(
                text=text,
                reference_audio=reference_audio,
                length_threshold=50,
                window_size=50,
                speed=speed,
                pitch=pitch,
            )
        async for chunk_data in generator:
            audio = audio_writer.write_chunk(chunk_data, finalize=False)
            yield audio
        end_chunk_data = audio_writer.write_chunk(finalize=True)
        yield end_chunk_data
        logger.debug(f"end_chunk_data length: {len(end_chunk_data)}")


if __name__ == "__main__":
    SparkTTSWorker.run()


================================================
FILE: gpt_server/model_worker/utils.py
================================================
import httpx
from loguru import logger
from fastapi import HTTPException
import base64
import io
import os
from PIL import Image
import re
import torch
from transformers import AutoConfig
from transformers import AutoModel
import sentence_transformers


def is_base64_image(data_string):
    pattern = r"^data:image\/[a-zA-Z+]+;base64,[A-Za-z0-9+/]+=*$"
    return bool(re.match(pattern, data_string))
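

# Illustrative sketch (hypothetical helper, not called elsewhere): the
# JinaVLForRanking branch of PoolingModel marks a query or document as "image"
# when it is an http(s) URL or a base64 data URL matching the pattern used by
# is_base64_image, and as "text" otherwise. Under this rule any http(s) URL
# counts as an image, whatever it actually points to.
import re

_IMAGE_DATA_URL = re.compile(r"^data:image\/[a-zA-Z+]+;base64,[A-Za-z0-9+/]+=*$")


def detect_input_type(value: str) -> str:
    """Return "image" for http(s) URLs and base64 image data URLs, else "text"."""
    if value.startswith(("http://", "https://")) or _IMAGE_DATA_URL.match(value):
        return "image"
    return "text"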


# Convert a PIL image to a base64-encoded string
def pil_to_base64(pil_img: Image.Image, format: str = "PNG"):
    buffered = io.BytesIO()
    pil_img.save(buffered, format=format)  # encode in the requested format (PNG by default)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _extract_base64(data_url: str):
    """Extract the raw base64 payload from a data URL."""
    return data_url.split(",", 1)[-1]  # keep everything after the first comma


async def _get_bytes_from_url(url: str) -> bytes:
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        if response.status_code != 200:
            raise HTTPException(
                status_code=400, detail="Failed to download data from the given URL"
            )
        return response.content


def bytesio2image(bytes_io: io.BytesIO) -> Image.Image:
    return Image.open(bytes_io)


def bytes2image(bytes_: bytes) -> Image.Image:
    bytes_io = io.BytesIO(bytes_)
    return Image.open(bytes_io)


async def load_base64_or_url(base64_or_url) -> io.BytesIO:
    # Choose how to read the input: download an http(s) URL, otherwise decode base64
    if base64_or_url.startswith("http://") or base64_or_url.startswith("https://"):
        audio_bytes = await _get_bytes_from_url(base64_or_url)
    else:
        try:
            if "data:" in base64_or_url:
                base64_or_url = _extract_base64(data_url=base64_or_url)
            audio_bytes = base64.b64decode(base64_or_url)
        except Exception as e:
            logger.warning("Invalid base64 data: " + str(e))
            raise HTTPException(status_code=400, detail="Invalid base64 data: " + str(e))
    # Wrap the raw bytes in a BytesIO
    try:
        bytes_io = io.BytesIO(audio_bytes)
    except Exception as e:
        logger.warning("Failed to read data: " + str(e))
        raise HTTPException(status_code=400, detail="Failed to read data: " + str(e))
    return bytes_io
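

# Illustrative sketch (hypothetical helper, not called elsewhere): the non-URL
# branch of load_base64_or_url drops an optional "data:<mime>;base64," prefix
# and decodes the remainder. This stdlib-only variant decodes strictly
# (validate=True), so characters outside the base64 alphabet raise an error
# instead of being silently discarded, as base64.b64decode does by default.
import base64
import binascii
import io


def decode_base64_or_data_url(value: str) -> io.BytesIO:
    """Decode plain base64 or a base64 data URL into an in-memory buffer."""
    if value.startswith("data:"):
        # Keep only the payload after the first comma
        value = value.split(",", 1)[-1]
    try:
        raw = base64.b64decode(value, validate=True)
    except binascii.Error as e:
        raise ValueError(f"invalid base64 data: {e}") from e
    return io.BytesIO(raw)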


def guess_tool_parser_by_model(model_path: str) -> str:
    """Guess the tool-call parser from the model's architecture."""
    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    architectures = getattr(model_config, "architectures", None) or []
    if not architectures:
        return "qwen2_5"
    architecture_lower = architectures[0].lower()

    for i in ["qwen3_5", "qwen3next"]:
        if i in architecture_lower:
            return "qwen3_coder"

    if "qwen" in architecture_lower:
        return "qwen2_5"
    if "minimaxm2" in architecture_lower:
        return "minimax_m2"
    return "qwen2_5"
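

# Illustrative sketch (hypothetical helper, not called elsewhere): the substring
# rules of guess_tool_parser_by_model applied to an architectures list directly,
# so the mapping can be checked without loading a config through AutoConfig.
# Order matters: the "qwen3_5" / "qwen3next" checks must run before the generic
# "qwen" check, which would otherwise match those architectures first.
from typing import Sequence


def guess_tool_parser_from_architectures(architectures: Sequence[str]) -> str:
    """Map the first architecture name to a tool parser name."""
    if not architectures:
        return "qwen2_5"  # fall back to the default parser
    arch = architectures[0].lower()
    if any(key in arch for key in ("qwen3_5", "qwen3next")):
        return "qwen3_coder"
    if "qwen" in arch:
        return "qwen2_5"
    if "minimaxm2" in arch:
        return "minimax_m2"
    return "qwen2_5"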


class PoolingModel:
    def __init__(self, model_path: str):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        architectures = getattr(model_config, "architectures", [])
        self.model = None
        self._pooling = None
        if "JinaForRanking" in architectures:
            self.model = AutoModel.from_pretrained(
                model_path,
                dtype="auto",
                trust_remote_code=True,
            )
            self.model.eval()
            self.model.to(device)  # Move model to device

            def pooling_(self, query: str, documents: list):
                results = self.model.rerank(query, documents)
                embedding = [[i["relevance_score"]] for i in results]
                ret = {}
                ret["embedding"] = embedding
                ret["token_num"] = 0
                return ret

            self._pooling = pooling_
        elif "JinaVLForRanking" in architectures:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                # attn_implementation="flash_attention_2",
            )
            self.model.to(device)
            self.model.eval()
            logger.warning("model_type: JinaVLForRanking")

            def pooling_(self, query: str, documents: list):
                texts = documents
                sentence_pairs = [[query, inp] for inp in texts]
                query_type = doc_type = "text"

                if (
                    query.startswith("http://")
                    or query.startswith("https://")
                    or is_base64_image(query)
                ):
                    query_type = "image"
                if (
                    texts
                    and texts[0]
                    and (
                        texts[0].startswith("http://")
                        or texts[0].startswith("https://")
                        or is_base64_image(texts[0])
                    )
                ):
                    doc_type = "image"
                scores = self.model.compute_score(
                    sentence_pairs,
                    max_length=1024 * 2,
                    query_type=query_type,
                    doc_type=doc_type,
                )
                if isinstance(scores, float):
                    scores = [scores]
                embedding = [[float(score)] for score in scores]
                ret = {}
                ret["embedding"] = embedding
                ret["token_num"] = 0
                return ret

            self._pooling = pooling_
        else:
            mode = get_embedding_mode(model_path=model_path)
            if "embedding" == mode:
                self.model = sentence_transformers.SentenceTransformer(model_path)
                logger.warning("Using an embedding model...")
                encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}

                def pooling_(self, query: str, documents: list = None):
                    texts = documents
                    outputs = self.model.tokenize(texts)
                    # Token usage = batch size x padded sequence length
                    token_num = outputs["input_ids"].size(0) * outputs[
                        "input_ids"
                    ].size(1)
                    texts = list(map(lambda x: x.replace("\n", " "), texts))
                    embedding = self.model.encode(texts, **encode_kwargs).tolist()
                    ret = {}
                    ret["embedding"] = embedding
                    ret["token_num"] = token_num
                    return ret

                self._pooling = pooling_

            elif "rerank" == mode:
                self.model = sentence_transformers.CrossEncoder(model_name=model_path)
                logger.warning("Using a rerank model...")

                def pooling_(self, query: str, documents: list):
                    sentence_pairs = [[query, doc] for doc in documents]
                    scores = self.model.predict(sentence_pairs)
                    embedding = [[float(score)] for score in scores]
                    ret = {}
                    ret["embedding"] = embedding
                    ret["token_num"] = 0  # Token usage is not computed for rerank requests
                    return ret

                self._pooling = pooling_

            else:
                raise Exception(f"Unsupported mode: {mode}")

    def pooling(self, query, documents):
        if self._pooling is None:
            raise Exception("Model is not initialized or mode is not supported.")
        return self._pooling(self, query, documents)


def patch():
    class _HfFolder:
        pass

    import huggingface_hub

    huggingface_hub.HfFolder = _HfFolder
    logger.warning("Patched huggingface_hub.HfFolder successfully!")


def get_embedding_mode(model_path: str):
    """Infer the model's task mode; the `task_type` environment variable takes precedence."""
    task_type = os.environ.get("task_type", "auto")
    if task_type == "embedding":
        return "embedding"
    elif task_type == "reranker":
        return "rerank"
    elif task_type == "classify":
        return "classify"

    model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type_text = getattr(
        getattr(model_config, "text_config", {}), "model_type", None
    )
    logger.warning(f"model_type: {model_type_text}")

    model_type = model_type_text
    # --------- Main filtering: let infinity_emb infer the engine type from the config ---------
    from infinity_emb import EngineArgs

    from infinity_emb.inference.select_model import get_engine_type_from_config

    engine_args = EngineArgs(
        model_name_or_path=model_path,
        engine="torch",
        embedding_dtype="float32",
        dtype="float32",
        bettertransformer=True,
    )
    engine_type = get_engine_type_from_config(engine_args)
    engine_type_str = str(engine_type)

    if "EmbedderEngine" in engine_type_str:
        return "embedding"
    elif "RerankEngine" in engine_type_str:
        return "rerank"
    elif "ImageEmbedEngine" in engine_type_str:
        return model_type or "image"
    elif "PredictEngine" in engine_type_str:
        return "classify"


if __name__ == "__main__":
    # Example usage
    r = get_embedding_mode("/home/dev/model/jinaai/jina-reranker-v3/")
    print(r)

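For `JinaVLForRanking`, `pooling_` decides per request whether the query and the documents are text or images: a string counts as an image if it is an `http://`/`https://` URL (regardless of extension) or base64 image data, and only the first document is inspected to type the whole list. A minimal sketch of that heuristic follows. `looks_like_image` and `detect_types` are hypothetical names, and the base64 test (a `data:image/` prefix) is an assumed simplification of the repository's `is_base64_image`, whose implementation is not shown here.

```python
from typing import List, Tuple


def looks_like_image(value: str) -> bool:
    """Hypothetical stand-in for the worker's URL / base64 image check.

    Assumption: base64 images arrive as data URIs ("data:image/...");
    the repository's is_base64_image may accept other forms.
    """
    return (
        value.startswith("http://")
        or value.startswith("https://")
        or value.startswith("data:image/")
    )


def detect_types(query: str, documents: List[str]) -> Tuple[str, str]:
    """Return (query_type, doc_type) following the JinaVL branch:
    only the first document decides the type of the whole list."""
    query_type = "image" if looks_like_image(query) else "text"
    doc_type = "text"
    if documents and documents[0] and looks_like_image(documents[0]):
        doc_type = "image"
    return query_type, doc_type


if __name__ == "__main__":
    print(detect_types("a red bicycle", ["https://example.com/bike.png", "a car"]))
    # → ('text', 'image')
```

Because only `documents[0]` is inspected, the mixed list in the usage example is scored entirely as images, including the plain-text entry "a car"; document lists sent to this branch should therefore be homogeneous.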

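`get_embedding_mode` resolves the mode in two stages: an explicit `task_type` environment variable wins outright, and otherwise the string form of the engine type returned by infinity_emb's `get_engine_type_from_config` is matched by substring. The sketch below isolates that decision logic. `resolve_mode` is hypothetical, it takes the engine-type string as an argument so it runs without infinity_emb installed, and the engine strings in the checks are illustrative rather than the library's exact output.

```python
import os
from typing import Optional

# Explicit overrides accepted via the `task_type` environment variable.
_TASK_TYPE_TO_MODE = {
    "embedding": "embedding",
    "reranker": "rerank",
    "classify": "classify",
}


def resolve_mode(
    engine_type_str: str,
    model_type: Optional[str] = None,
    task_type: Optional[str] = None,
) -> Optional[str]:
    """Map an override or an engine-type string to a worker mode.

    engine_type_str stands in for str(get_engine_type_from_config(...)).
    """
    if task_type is None:
        task_type = os.environ.get("task_type", "auto")
    # 1. An explicit override always wins.
    if task_type in _TASK_TYPE_TO_MODE:
        return _TASK_TYPE_TO_MODE[task_type]
    # 2. Otherwise fall back to the engine type inferred from the config.
    if "EmbedderEngine" in engine_type_str:
        return "embedding"
    if "RerankEngine" in engine_type_str:
        return "rerank"
    if "ImageEmbedEngine" in engine_type_str:
        return model_type or "image"
    if "PredictEngine" in engine_type_str:
        return "classify"
    # 3. No match: like the original, return None and let the caller decide.
    return None
```

Note that the override key is `reranker` while the resulting mode is `rerank`, matching the mapping in the worker code.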
================================================
FILE: gpt_server/model_worker/voxcpm_tts.py
================================================
import os
from typing import List
from loguru import logger
import numpy as np
from gpt_server.model_handler.pitch import pitch_flashtts

pitch_flashtts()
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from flashtts.server.utils.audio_writer import StreamingAudioWriter
import soundfile as sf
from voxcpm import VoxCPM

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class VoxCPMTTSWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="tts",
        )
        self.model = VoxCPM.from_pretrained(model_path)
        logger.warning(f"Model: {model_names[0]}")

    # Main entry point of this worker: stream synthesized audio chunks
    async def generate_voice_stream(self, params):
        # The Spark-TTS-only guard copied from spark_tts.py can never pass for
        # VoxCPM (this worker sets no Spark engine), so stream directly.
        async for chunk_data in self.stream_async(params=params):
            yield chunk_data

    async def stream_async(self, params):
        text = params["text"]
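        # voice, speed and pitch are read for API compatibility but are not
        # passed to VoxCPM's generate() call below.
        voice = params.get("voice", "新闻联播女声")  # default preset: news-anchor female voice
        response_format = params["response_format"]
        speed = params["speed"]
        pitch = params["pitch"]
        sample_rate = 16 * 1000
        audio_writer = StreamingAudioWriter(
            format=response_format, sample_rate=sample_rate
        )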
        wav = self.model.generate(
            text=text,
            prompt_wav_path=None,  # optional: path to a prompt speech for voice cloning
            prompt_text=None,  # optional: reference text
            cfg_value=2.0,  # LM guidance on LocDiT: higher follows the prompt more closely, possibly at lower quality
            inference_timesteps=10,  # LocDiT inference steps: higher for better quality, lower for speed
            normalize=True,  # enable the external text-normalization (TN) tool
            denoise=True,  # enable the external denoising tool
            retry_badcase=True,  # retry generation for some bad cases (e.g. speech that never stops)
            retry_badcase_max_times=3,  # maximum number of retries
            retry_badcase_ratio_threshold=6.0,  # length-ratio threshold for bad-case detection; raise it for slow-paced speech
        )

        # Split the waveform into chunks of 1024 samples
        chunk_size = 1024
        for i in range(0, len(wav), chunk_size):
            chunk = wav[i : i + chunk_size]
            yield audio_writer.write_chunk(chunk.astype(np.float32))
        # Flush the writer and emit the final chunk
        end_chunk_data = audio_writer.write_chunk(finalize=True)
        yield end_chunk_data

        logger.debug(f"end_chunk_data length: {len(end_chunk_data)}")


if __name__ == "__main__":
    VoxCPMTTSWorker.run()

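`stream_async` synthesizes the whole utterance with a single blocking `self.model.generate(...)` call and only then slices the waveform into 1024-sample chunks (about 64 ms each at 16 kHz) for `StreamingAudioWriter`, so the client receives its first bytes only after synthesis has finished. The slicing step can be sketched on its own, with plain lists standing in for the NumPy array and no audio encoding; `stream_in_chunks` and `collect` are hypothetical helpers, not part of the worker.

```python
import asyncio
from typing import AsyncIterator, List, Sequence


async def stream_in_chunks(
    samples: Sequence[float], chunk_size: int = 1024
) -> AsyncIterator[List[float]]:
    """Yield fixed-size slices of an already synthesized waveform.

    The last slice may be shorter than chunk_size.
    """
    for start in range(0, len(samples), chunk_size):
        yield list(samples[start : start + chunk_size])
        # Give other tasks on the event loop a chance to run between chunks.
        await asyncio.sleep(0)


async def collect(samples: Sequence[float], chunk_size: int) -> List[List[float]]:
    return [chunk async for chunk in stream_in_chunks(samples, chunk_size)]


if __name__ == "__main__":
    chunks = asyncio.run(collect([0.0] * 2500, 1024))
    print([len(c) for c in chunks])  # → [1024, 1024, 452]
```

One option the worker does not currently use is running the blocking `generate` call through `asyncio.to_thread`, which would keep the event loop responsive while the waveform is being produced.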

================================================
FILE: gpt_server/model_worker/wan.py
================================================
import asyncio

import io
import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import pil_to_base64
from gpt_server.utils import STATIC_DIR
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class WanWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="image",
        )
        backend = os.environ["backend"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        vae = AutoencoderKLWan.from_pretrained(
            model_path, subfolder="vae", torch_dtype=torch.float32
        )
        self.pipe = WanPipeline.from_pretrained(
            model_path, vae=vae, torch_dtype=torch.bfloat16
        ).to(self.device)
        logger.warning(f"Model: {model_names[0]}")

    async def get_image_output(self, params):
        prompt = params["prompt"]
        negative_prompt = (
            "Bright tones, overexposed, static, blurred details, subtitles, style, "
            "works, paintings, images, static, overall gray, worst quality, "
            "low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
            "poorly drawn hands, poorly drawn faces, deformed, disfigured, "
            "misshapen limbs, fused fingers, still picture, messy background, "
            "three legs, many people in the background, walking backwards"
        )
        output = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=81,
            guidance_scale=5.0,
        ).frames[0]

        # Generate a unique file name to avoid collisions
        file_name = str(uuid.uuid4()) + ".mp4"
        save_path = STATIC_DIR / file_name
        export_to_video(output, str(save_path), fps=15)
        WORKER_PORT = os.environ["WORKER_PORT"]
        WORKER_HOST = os.environ["WORKER_HOST"]
        url = f"http://{WORKER_HOST}:{WORKER_PORT}/static/{file_name}"
        result = {
            "created": shortuuid.random(),
            "data": [{"url": url}],
            "usage": {
                "total_tokens": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
            },
        }
        return result


if __name__ == "__main__":
    WanWorker.run()

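`get_image_output` writes each video under `STATIC_DIR` with a UUID file name and returns a URL built from the `WORKER_HOST` and `WORKER_PORT` environment variables, wrapped in an images-style payload with zeroed usage. A minimal sketch of that bookkeeping follows; `build_video_response` is a hypothetical helper, and using `int(time.time())` for `created` is a swapped-in choice (OpenAI's Images API reports `created` as a Unix timestamp, whereas the worker above stores `shortuuid.random()` there).

```python
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Tuple


def build_video_response(
    static_dir: Path, host: str, port: str
) -> Tuple[Path, Dict[str, Any]]:
    """Pick a collision-free file name and build the JSON response body.

    Returns the path the video should be written to and the payload.
    """
    file_name = f"{uuid.uuid4()}.mp4"
    save_path = static_dir / file_name
    url = f"http://{host}:{port}/static/{file_name}"
    body = {
        # Assumption: Unix timestamp, as in OpenAI's Images API responses.
        "created": int(time.time()),
        "data": [{"url": url}],
        "usage": {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        },
    }
    return save_path, body


if __name__ == "__main__":
    path, body = build_video_response(Path("/tmp/static"), "127.0.0.1", "8082")
    print(path.name, body["data"][0]["url"])
```

The returned URL is only useful if the configured host and port are reachable from the client and the `/static` route serves `STATIC_DIR`.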

================================================
FILE: gpt_server/model_worker/z_image.py
================================================
import asyncio
import os
from typing import List
import uuid
from loguru import logger
import shortuuid
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
from gpt_server.model_worker.utils import pil_to_base64
import torch
from diffusers import ZImagePipeline
from gpt_server.utils import STATIC_DIR

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}

width, height = aspect_ratios["16:9"]
import re


def contains_chinese(text):
    pattern = re.compile(r"[\u4e00-\u9fff]")
    return bool(pattern.search(text))


class ZImageWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None, 
SYMBOL INDEX (370 symbols across 52 files)

FILE: gpt_server/cli.py
  function ui (line 13) | def ui(
  function main (line 25) | def main():

FILE: gpt_server/database/models/process_manager.py
  class ProcessRecord (line 12) | class ProcessRecord(SQLModel, table=True):
  class ProcessManager (line 24) | class ProcessManager:
    method __init__ (line 25) | def __init__(self, write_db: bool = False, db_url: str = "sqlite:///pr...
    method add_process (line 42) | def add_process(
    method start_all (line 64) | def start_all(self):
    method join_all (line 87) | def join_all(self):

FILE: gpt_server/model_backend/base.py
  class ModelBackend (line 5) | class ModelBackend(ABC):
    method stream_chat (line 7) | def stream_chat(self, params: Dict[str, Any]):
    method shutdown (line 10) | def shutdown(self):

FILE: gpt_server/model_backend/hf_backend.py
  class NoneContextManager (line 22) | class NoneContextManager:
    method __enter__ (line 23) | def __enter__(self):
    method __exit__ (line 26) | def __exit__(self, exc_type, exc_val, exc_tb):
  class HFBackend (line 30) | class HFBackend(ModelBackend):
    method __init__ (line 31) | def __init__(self, tokenizer: PreTrainedTokenizer, model: torch.nn.Mod...
    method shutdown (line 57) | def shutdown(self):
    method stream_chat (line 60) | async def stream_chat(self, params: Dict[str, Any]):

FILE: gpt_server/model_backend/lmdeploy_backend.py
  class CustomRequestLogger (line 40) | class CustomRequestLogger(RequestLogger):
    method log_prompt (line 41) | def log_prompt(self, session_id: int, prompt: str) -> None:
    method log_inputs (line 47) | def log_inputs(
  class LMDeployBackend (line 70) | class LMDeployBackend(ModelBackend):
    method __init__ (line 71) | def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
    method shutdown (line 106) | def shutdown(self):
    method stream_chat (line 109) | async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:

FILE: gpt_server/model_backend/sglang_backend.py
  class CustomOpenAIServingResponses (line 29) | class CustomOpenAIServingResponses(OpenAIServingResponses):
    method _process_messages (line 30) | def _process_messages(self, request, is_multimodal):
  class CustomOpenAIServingChat (line 39) | class CustomOpenAIServingChat(OpenAIServingChat):
    method _process_messages (line 40) | def _process_messages(self, request, is_multimodal):
  class SGLangBackend (line 49) | class SGLangBackend(ModelBackend):
    method __init__ (line 50) | def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
    method shutdown (line 91) | def shutdown(self):
    method stream_chat (line 94) | async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:

FILE: gpt_server/model_backend/utils.py
  class XgrammarLogitsProcessor (line 15) | class XgrammarLogitsProcessor(LogitsProcessor):
    method __init__ (line 16) | def __init__(self, tokenizer: PreTrainedTokenizerBase):
    method get_json_grammar_processor (line 21) | def get_json_grammar_processor(self):
    method get_json_schema_processor (line 26) | def get_json_schema_processor(self, schema: Union[str, Type[BaseModel]]):
    method __call__ (line 31) | def __call__(
  class InvalidScoreLogitsProcessor (line 37) | class InvalidScoreLogitsProcessor(LogitsProcessor):
    method __call__ (line 38) | def __call__(
  class StopAtSpecificTokenCriteria (line 47) | class StopAtSpecificTokenCriteria(StoppingCriteria):
    method __init__ (line 52) | def __init__(self, token_id_list: List[int] = None):
    method __call__ (line 60) | def __call__(

FILE: gpt_server/model_backend/vllm_backend.py
  class CustomOpenAIServingResponses (line 24) | class CustomOpenAIServingResponses(OpenAIServingResponses):
    method _preprocess_chat (line 25) | async def _preprocess_chat(self, *args, **kwargs):
  class CustomOpenAIServingChat (line 36) | class CustomOpenAIServingChat(OpenAIServingChat):
    method render_chat_request (line 37) | async def render_chat_request(self, request):
  class VllmBackend (line 47) | class VllmBackend(ModelBackend):
    method __init__ (line 48) | def __init__(self, model_path, tokenizer: PreTrainedTokenizer) -> None:
    method shutdown (line 126) | def shutdown(self):
    method stream_chat (line 130) | async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:

FILE: gpt_server/model_handler/chat_template/get_chat_template.py
  function get_chat_template (line 7) | def get_chat_template(model_name: str = "", lang: Literal["en", "zh"] = ...

FILE: gpt_server/model_handler/pitch.py
  class VllmGenerator_ (line 7) | class VllmGenerator_(VllmGenerator):
    method __init__ (line 8) | def __init__(
  function pitch_flashtts (line 41) | def pitch_flashtts():

FILE: gpt_server/model_handler/reasoning_parser.py
  class DeepSeekR1ReasoningParser (line 15) | class DeepSeekR1ReasoningParser(ReasoningParser):
    method __init__ (line 22) | def __init__(self, tokenizer: object):
    method extract_reasoning_content_streaming (line 45) | def extract_reasoning_content_streaming(
    method extract_reasoning_content (line 126) | def extract_reasoning_content(

FILE: gpt_server/model_handler/tool_parser.py
  class ToolCall (line 16) | class ToolCall(BaseModel):
  class ExtractedToolCallInformation (line 25) | class ExtractedToolCallInformation(BaseModel):
  class Qwen2d5ToolParser (line 37) | class Qwen2d5ToolParser(ToolParser):
    method __init__ (line 38) | def __init__(self, tokenizer: object):
    method extract_tool_calls (line 45) | def extract_tool_calls(
  function tool_parser (line 122) | def tool_parser(full_text: str, tool_parser_: ToolParser, tools, ret):
  class ToolCallStreamProcessor (line 165) | class ToolCallStreamProcessor:
    method __init__ (line 170) | def __init__(self):
    method process_chunk (line 174) | def process_chunk(self, tool_calls_data: List[Dict]) -> Optional[List[...
    method get_completed_tool_calls (line 211) | def get_completed_tool_calls(self) -> Optional[List[Dict]]:
    method reset (line 245) | def reset(self):

FILE: gpt_server/model_worker/__init__.py
  function patch_infinity_embedder (line 8) | def patch_infinity_embedder():
  function patch_infinity_crossencoder (line 37) | def patch_infinity_crossencoder():

FILE: gpt_server/model_worker/auto.py
  class AutoWorker (line 16) | class AutoWorker(ModelWorkerBase):
    method __init__ (line 17) | def __init__(
    method generate_stream_gate (line 54) | async def generate_stream_gate(self, params):

FILE: gpt_server/model_worker/base/base_model_worker.py
  function build_logger (line 13) | def build_logger():
  function heart_beat_worker (line 25) | def heart_beat_worker(obj: "BaseModelWorker"):
  class BaseModelWorker (line 31) | class BaseModelWorker:
    method __init__ (line 32) | def __init__(
    method make_conv_template (line 67) | def make_conv_template(
    method init_heart_beat (line 84) | def init_heart_beat(self):
    method register_to_controller (line 93) | def register_to_controller(self):
    method send_heart_beat (line 106) | def send_heart_beat(self):
    method get_queue_length (line 129) | def get_queue_length(self):
    method get_status (line 143) | def get_status(self):
    method count_token (line 150) | def count_token(self, params):
    method get_conv_template (line 165) | def get_conv_template(self):
    method generate_stream_gate (line 168) | def generate_stream_gate(self, params):
    method generate_gate (line 171) | def generate_gate(self, params):
    method get_embeddings (line 174) | def get_embeddings(self, params):
    method classify (line 177) | def classify(self, params):
    method transcription (line 180) | def transcription(self, params):
    method generate_voice_stream (line 183) | def generate_voice_stream(self, params):
    method get_image_output (line 186) | def get_image_output(self, params):

FILE: gpt_server/model_worker/base/model_worker_base.py
  function get_context_length_ (line 34) | def get_context_length_(config):
  function cleanup_static_files (line 52) | async def cleanup_static_files():
  function run_scheduler (line 62) | async def run_scheduler():
  function pop_matching_tool (line 69) | def pop_matching_tool(tools, tool_choice):
  class ModelWorkerBase (line 82) | class ModelWorkerBase(BaseModelWorker, ABC):
    method __init__ (line 83) | def __init__(
    method preprocess_params (line 141) | def preprocess_params(self, params: dict) -> dict:
    method get_context_length (line 176) | def get_context_length(
    method get_model_class (line 184) | def get_model_class(self):
    method load_model_tokenizer (line 199) | def load_model_tokenizer(self, model_path):
    method generate_gate (line 247) | async def generate_gate(self, params):
    method get_worker (line 272) | def get_worker(
    method run (line 294) | def run(cls):
  function release_worker_semaphore (line 426) | def release_worker_semaphore():
  function acquire_worker_semaphore (line 430) | def acquire_worker_semaphore():
  function create_background_tasks (line 436) | def create_background_tasks(request_id):
  function gen_request_id (line 446) | def gen_request_id():
  function api_generate_stream (line 453) | async def api_generate_stream(request: Request):
  function api_generate_stream (line 468) | async def api_generate_stream(request: Request):
  function api_generate (line 500) | async def api_generate(request: Request):
  function api_get_status (line 517) | async def api_get_status(request: Request):
  function api_count_token (line 522) | async def api_count_token(request: Request):
  function api_get_conv (line 528) | async def api_get_conv(request: Request):
  function api_model_details (line 533) | async def api_model_details(request: Request):
  function api_get_embeddings (line 538) | async def api_get_embeddings(request: Request):
  function api_get_embeddings (line 548) | async def api_get_embeddings(request: Request):
  function api_get_classify (line 558) | async def api_get_classify(request: Request):
  function api_get_transcription (line 568) | async def api_get_transcription(request: Request):

FILE: gpt_server/model_worker/embedding_infinity.py
  class EmbeddingWorker (line 24) | class EmbeddingWorker(ModelWorkerBase):
    method __init__ (line 25) | def __init__(
    method astart (line 70) | async def astart(self):
    method get_embeddings (line 73) | async def get_embeddings(self, params):
    method classify (line 115) | async def classify(self, params):

FILE: gpt_server/model_worker/embedding_sentence_transformers.py
  class EmbeddingWorker (line 11) | class EmbeddingWorker(ModelWorkerBase):
    method __init__ (line 12) | def __init__(
    method get_embeddings (line 40) | async def get_embeddings(self, params):

FILE: gpt_server/model_worker/embedding_v2.py
  class EmbeddingWorker (line 10) | class EmbeddingWorker(ModelWorkerBase):
    method __init__ (line 11) | def __init__(
    method warm_up (line 67) | def warm_up(self):
    method batch_processor (line 74) | async def batch_processor(self):
    method add_request (line 132) | async def add_request(self, params: dict, future: asyncio.Future):
    method aembed (line 136) | async def aembed(self, params: dict, future: asyncio.Future):
    method rerank (line 139) | async def rerank(self, params: dict, future: asyncio.Future):
    method get_embeddings (line 142) | async def get_embeddings(self, params):

FILE: gpt_server/model_worker/embedding_vllm.py
  function template_format (line 24) | def template_format(queries: List[str], documents: List[str]):
  class EmbeddingWorker (line 42) | class EmbeddingWorker(ModelWorkerBase):
    method __init__ (line 43) | def __init__(
    method get_embeddings (line 82) | async def get_embeddings(self, params):

FILE: gpt_server/model_worker/flux.py
  class FluxWorker (line 18) | class FluxWorker(ModelWorkerBase):
    method __init__ (line 19) | def __init__(
    method get_image_output (line 47) | async def get_image_output(self, params):

FILE: gpt_server/model_worker/funasr.py
  class FunASRWorker (line 11) | class FunASRWorker(ModelWorkerBase):
    method __init__ (line 12) | def __init__(
    method transcription (line 48) | async def transcription(self, params):

FILE: gpt_server/model_worker/qwen_image.py
  function contains_chinese (line 34) | def contains_chinese(text):
  class QwenImageWorker (line 39) | class QwenImageWorker(ModelWorkerBase):
    method __init__ (line 40) | def __init__(
    method get_image_output (line 68) | async def get_image_output(self, params):

FILE: gpt_server/model_worker/qwen_image_edit.py
  class QwenImageEditWorker (line 21) | class QwenImageEditWorker(ModelWorkerBase):
    method __init__ (line 22) | def __init__(
    method get_image_output (line 49) | async def get_image_output(self, params):

FILE: gpt_server/model_worker/spark_tts.py
  class SparkTTSWorker (line 18) | class SparkTTSWorker(ModelWorkerBase):
    method __init__ (line 19) | def __init__(
    method generate_voice_stream (line 66) | async def generate_voice_stream(self, params):
    method stream_async (line 72) | async def stream_async(self, params):

FILE: gpt_server/model_worker/utils.py
  function is_base64_image (line 15) | def is_base64_image(data_string):
  function pil_to_base64 (line 21) | def pil_to_base64(pil_img: Image.Image, format: str = "PNG"):
  function _extract_base64 (line 27) | def _extract_base64(data_url: str):
  function _get_bytes_from_url (line 32) | async def _get_bytes_from_url(url: str) -> bytes:
  function bytesio2image (line 40) | def bytesio2image(bytes_io: io.BytesIO) -> Image.Image:
  function bytes2image (line 44) | def bytes2image(bytes_: bytes) -> Image.Image:
  function load_base64_or_url (line 49) | async def load_base64_or_url(base64_or_url) -> io.BytesIO:
  function guess_tool_parser_by_model (line 70) | def guess_tool_parser_by_model(model_path: str) -> str:
  class PoolingModel (line 88) | class PoolingModel:
    method __init__ (line 89) | def __init__(self, model_path: str):
    method pooling (line 200) | def pooling(self, query, documents):
  function patch (line 206) | def patch():
  function get_embedding_mode (line 216) | def get_embedding_mode(model_path: str):

FILE: gpt_server/model_worker/voxcpm_tts.py
  class VoxCPMTTSWorker (line 16) | class VoxCPMTTSWorker(ModelWorkerBase):
    method __init__ (line 17) | def __init__(
    method generate_voice_stream (line 41) | async def generate_voice_stream(self, params):
    method stream_async (line 47) | async def stream_async(self, params):

FILE: gpt_server/model_worker/wan.py
  class WanWorker (line 19) | class WanWorker(ModelWorkerBase):
    method __init__ (line 20) | def __init__(
    method get_image_output (line 50) | async def get_image_output(self, params):

FILE: gpt_server/model_worker/z_image.py
  function contains_chinese (line 29) | def contains_chinese(text):
  class ZImageWorker (line 34) | class ZImageWorker(ModelWorkerBase):
    method __init__ (line 35) | def __init__(
    method get_image_output (line 63) | async def get_image_output(self, params):

FILE: gpt_server/openai_api_protocol/custom_api_protocol.py
  class UsageInfo (line 68) | class UsageInfo(BaseModel):
  class ErrorInfo (line 77) | class ErrorInfo(BaseModel):
  class ErrorResponseV2 (line 84) | class ErrorResponseV2(BaseModel):
  class InputTokensDetails (line 88) | class InputTokensDetails(BaseModel):
  class OutputTokensDetails (line 94) | class OutputTokensDetails(BaseModel):
  class ResponseUsage (line 100) | class ResponseUsage(BaseModel):
  class ResponseReasoningParam (line 108) | class ResponseReasoningParam(BaseModel):
  class RequestResponseMetadata (line 117) | class RequestResponseMetadata(BaseModel):
  class ResponsesRequest (line 122) | class ResponsesRequest(BaseModel):
  class ResponsesResponse (line 183) | class ResponsesResponse(BaseModel):
    method from_request (line 219) | def from_request(
  class ImagesGenRequest (line 262) | class ImagesGenRequest(BaseModel):
  class OpenAISpeechRequest (line 281) | class OpenAISpeechRequest(BaseModel):
  class SpeechRequest (line 313) | class SpeechRequest(BaseModel):
  class ModerationsRequest (line 337) | class ModerationsRequest(BaseModel):
  class RerankRequest (line 343) | class RerankRequest(BaseModel):
  class EmbeddingsResponse (line 352) | class EmbeddingsResponse(BaseModel):
  class ModelPermission (line 359) | class ModelPermission(BaseModel):
  class CustomModelCard (line 374) | class CustomModelCard(BaseModel):
  class ModelList (line 384) | class ModelList(BaseModel):
  class CustomEmbeddingsRequest (line 389) | class CustomEmbeddingsRequest(BaseModel):
  class CustomChatCompletionRequest (line 398) | class CustomChatCompletionRequest(BaseModel):
  class ChatMessage (line 423) | class ChatMessage(BaseModel):
  class CustomChatMessage (line 428) | class CustomChatMessage(ChatMessage):
  class CustomChatCompletionResponseChoice (line 433) | class CustomChatCompletionResponseChoice(BaseModel):
  class LogProbs (line 439) | class LogProbs(BaseModel):
  class CustomCompletionResponseChoice (line 446) | class CustomCompletionResponseChoice(BaseModel):
  class CustomChatCompletionResponse (line 456) | class CustomChatCompletionResponse(BaseModel):
  class CustomDeltaMessage (line 466) | class CustomDeltaMessage(BaseModel):
  class CustomChatCompletionResponseStreamChoice (line 473) | class CustomChatCompletionResponseStreamChoice(BaseModel):
  class CustomChatCompletionStreamResponse (line 479) | class CustomChatCompletionStreamResponse(BaseModel):
  class CompletionResponse (line 488) | class CompletionResponse(BaseModel):

FILE: gpt_server/serving/chat_ui.py
  function clear_chat_history (line 37) | def clear_chat_history():
  function init_chat_history (line 41) | def init_chat_history():
  function main (line 56) | def main():

FILE: gpt_server/serving/controller.py
  class DispatchMethod (line 32) | class DispatchMethod(Enum):
    method from_str (line 37) | def from_str(cls, name):
  class WorkerInfo (line 47) | class WorkerInfo:
  function heart_beat_controller (line 56) | def heart_beat_controller(controller):
  class Controller (line 62) | class Controller:
    method __init__ (line 63) | def __init__(self, dispatch_method: str):
    method register_worker (line 73) | def register_worker(
    method get_worker_status (line 102) | def get_worker_status(self, worker_name: str):
    method remove_worker (line 115) | def remove_worker(self, worker_name: str):
    method refresh_all_workers (line 118) | def refresh_all_workers(self):
    method list_models (line 128) | def list_models(self):
    method list_multimodal_models (line 136) | def list_multimodal_models(self):
    method list_language_models (line 145) | def list_language_models(self):
    method get_worker_address (line 207) | def get_worker_address(self, model_name: str):
    method receive_heart_beat (line 215) | def receive_heart_beat(self, worker_name: str, queue_length: int):
    method remove_stale_workers_by_expiration (line 225) | def remove_stale_workers_by_expiration(self):
    method handle_no_worker (line 235) | def handle_no_worker(self, params):
    method handle_worker_timeout (line 243) | def handle_worker_timeout(self, worker_address):
    method worker_api_get_status (line 253) | def worker_api_get_status(self):
    method worker_api_generate_stream (line 272) | def worker_api_generate_stream(self, params):
  function register_worker (line 295) | async def register_worker(request: Request):
  function refresh_all_workers (line 306) | async def refresh_all_workers():
  function list_models (line 311) | async def list_models():
  function list_multimodal_models (line 317) | async def list_multimodal_models():
  function list_language_models (line 323) | async def list_language_models():
  function get_worker_address (line 329) | async def get_worker_address(request: Request):
  function receive_heart_beat (line 336) | async def receive_heart_beat(request: Request):
  function worker_api_generate_stream (line 343) | async def worker_api_generate_stream(request: Request):
  function worker_api_get_status (line 350) | async def worker_api_get_status(request: Request):
  function worker_api_get_status (line 355) | async def worker_api_get_status(request: Request):
  function create_controller (line 359) | def create_controller():
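
The controller's docstring ("A controller manages distributed workers. It sends worker addresses to clients.") and its method names closely mirror FastChat's controller. FastChat supports two dispatch strategies: `lottery`, a random pick weighted by worker speed, and `shortest_queue`, the worker with the lowest `queue_length / speed`. Assuming gpt_server keeps those two strategies, worker selection in `get_worker_address` can be sketched as below. `WorkerStat` and `pick_worker` are hypothetical names, and a plain dict stands in for the controller's `WorkerInfo` records.

```python
import random
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, Optional


class DispatchMethod(Enum):
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name: str) -> "DispatchMethod":
        # Map the configured string to an enum member.
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        raise ValueError(f"Invalid dispatch method: {name}")


@dataclass
class WorkerStat:
    # Hypothetical per-worker record; speed is assumed to be > 0.
    speed: float
    queue_length: int


def pick_worker(
    workers: Dict[str, WorkerStat],
    method: DispatchMethod,
    rng: Optional[random.Random] = None,
) -> str:
    """Return the address of the worker that should serve the next request."""
    if not workers:
        return ""  # empty address signals "no worker available"
    if method is DispatchMethod.SHORTEST_QUEUE:
        # Normalise queue length by speed so faster workers absorb more load.
        return min(workers, key=lambda a: workers[a].queue_length / workers[a].speed)
    # Lottery: random choice weighted by each worker's speed.
    rng = rng or random.Random()
    addrs = list(workers)
    return rng.choices(addrs, weights=[workers[a].speed for a in addrs], k=1)[0]
```

For example, with worker A at speed 1.0 holding 3 queued requests and worker B at speed 2.0 holding 4, `shortest_queue` picks B, because 4 / 2.0 = 2 is less than 3 / 1.0 = 3.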

FILE: gpt_server/serving/controller_v2.py
  class DispatchMethod (line 34) | class DispatchMethod(Enum):
    method from_str (line 39) | def from_str(cls, name):
  class Worker (line 50) | class Worker(SQLModel, table=True):
  function create_db_and_tables (line 72) | def create_db_and_tables():
  function heart_beat_controller (line 79) | def heart_beat_controller(controller: "Controller"):
  class Controller (line 86) | class Controller:
    method __init__ (line 87) | def __init__(self, dispatch_method: str, db_engine):
    method get_session (line 97) | def get_session(self):
    method register_worker (line 101) | def register_worker(
    method get_worker_status (line 147) | def get_worker_status(self, worker_addr: str):
    method remove_worker (line 161) | def remove_worker(self, worker_addr: str):
    method refresh_all_workers (line 174) | def refresh_all_workers(self):
    method list_models (line 196) | def list_models(self):
    method list_multimodal_models (line 207) | def list_multimodal_models(self):
    method list_language_models (line 217) | def list_language_models(self):
    method get_worker_address (line 227) | def get_worker_address(self, model_name: str):
    method receive_heart_beat (line 242) | def receive_heart_beat(self, worker_addr: str, queue_length: int):
    method remove_stale_workers_by_expiration (line 257) | def remove_stale_workers_by_expiration(self):
    method handle_no_worker (line 278) | def handle_no_worker(self, params):
    method handle_worker_timeout (line 287) | def handle_worker_timeout(self, worker_address):
  function register_worker (line 301) | async def register_worker(request: Request):
  function refresh_all_workers (line 312) | async def refresh_all_workers():
  function list_models (line 317) | async def list_models():
  function list_multimodal_models (line 323) | async def list_multimodal_models():
  function list_language_models (line 329) | async def list_language_models():
  function get_worker_address (line 335) | async def get_worker_address(request: Request):
  function receive_heart_beat (line 342) | async def receive_heart_beat(request: Request):
  function worker_api_get_status (line 350) | async def worker_api_get_status(request: Request):
  function create_controller (line 354) | def create_controller(db_engine_to_use):
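
`receive_heart_beat` and `remove_stale_workers_by_expiration` describe a heartbeat lease. Each worker periodically reports its queue length, and any worker whose last report is older than an expiration window is dropped. The following is a minimal in-memory sketch of that idea. It assumes a dict instead of the SQLModel `Worker` table and an injectable clock so that it can be tested. `HeartbeatRegistry` and its `expiration` parameter are hypothetical. As in FastChat, a heartbeat from an unknown worker returns `False` so that the worker knows to re-register.

```python
import time
from typing import Callable, Dict, List, Tuple


class HeartbeatRegistry:
    """Tracks each worker's last heartbeat time and reported queue length."""

    def __init__(self, expiration: float, clock: Callable[[], float] = time.time):
        self.expiration = expiration  # seconds of silence before a worker is dropped
        self.clock = clock
        # worker address -> (time of last heartbeat, queue length)
        self.workers: Dict[str, Tuple[float, int]] = {}

    def register_worker(self, worker_addr: str) -> None:
        self.workers[worker_addr] = (self.clock(), 0)

    def receive_heart_beat(self, worker_addr: str, queue_length: int) -> bool:
        if worker_addr not in self.workers:
            return False  # unknown worker: it should register again
        self.workers[worker_addr] = (self.clock(), queue_length)
        return True

    def remove_stale_workers(self) -> List[str]:
        cutoff = self.clock() - self.expiration
        stale = [addr for addr, (beat, _) in self.workers.items() if beat < cutoff]
        for addr in stale:
            del self.workers[addr]
        return stale
```

In the real controller, `heart_beat_controller` presumably calls the expiration check in a periodic background loop rather than on demand.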

FILE: gpt_server/serving/main.py
  function get_enabled_models (line 43) | def get_enabled_models(config):
  function main (line 57) | def main():

FILE: gpt_server/serving/openai_api_server.py
  function fetch_remote (line 64) | async def fetch_remote(url, pload=None, name=None):
  class AppSettings (line 88) | class AppSettings(BaseSettings):
    method split_api_keys (line 94) | def split_api_keys(cls, v):
    class Config (line 99) | class Config:
      method parse_env_var (line 102) | def parse_env_var(cls, field_name: str, raw_val: str):
  function timing_tasks (line 113) | async def timing_tasks():
  function lifespan (line 143) | async def lifespan(app: fastapi.FastAPI):
  function check_api_key (line 154) | async def check_api_key(
  function create_error_response (line 176) | def create_error_response(code: int, message: str) -> JSONResponse:
  function validation_exception_handler (line 183) | async def validation_exception_handler(request: Request, exc: RequestVal...
  function check_model (line 187) | def check_model(model: str) -> Optional[JSONResponse]:
  function process_input (line 199) | def process_input(model_name, inp):
  function create_openai_logprobs (line 223) | def create_openai_logprobs(logprob_dict):
  function _add_to_set (line 228) | def _add_to_set(s, new_stop):
  function get_gen_params (line 237) | def get_gen_params(
  class AddressManager (line 302) | class AddressManager:
    method __init__ (line 303) | def __init__(self):
    method get_address (line 307) | def get_address(self, model):
  function get_worker_address (line 329) | def get_worker_address(model_name: str) -> str:
  function get_conv (line 348) | async def get_conv(model_name: str, worker_addr: str):
  function show_available_models (line 370) | async def show_available_models():
  function get_model_address_map (line 408) | def get_model_address_map():
  function create_responses (line 418) | async def create_responses(request: ResponsesRequest):
  function create_chat_completion (line 460) | async def create_chat_completion(request: CustomChatCompletionRequest):
  function chat_completion_stream_generator (line 547) | async def chat_completion_stream_generator(
  function create_completion (line 603) | async def create_completion(request: CompletionRequest):
  function generate_completion_stream_generator (line 674) | async def generate_completion_stream_generator(
  function generate_completion_stream (line 733) | async def generate_completion_stream(payload: Dict[str, Any], worker_add...
  function generate_completion (line 754) | async def generate_completion(payload: Dict[str, Any], worker_addr: str):
  function get_images_edits (line 769) | async def get_images_edits(payload: Dict[str, Any]):
  function images_edits (line 780) | async def images_edits(
  function get_images_gen (line 812) | async def get_images_gen(payload: Dict[str, Any]):
  function images_generations (line 823) | async def images_generations(request: ImagesGenRequest):
  function generate_voice_stream (line 845) | async def generate_voice_stream(payload: Dict[str, Any], worker_addr: str):
  function speech (line 862) | async def speech(request: OpenAISpeechRequest):
  function get_transcriptions (line 908) | async def get_transcriptions(payload: Dict[str, Any]):
  function transcriptions (line 924) | async def transcriptions(file: UploadFile, model: str = Form()):
  function classify (line 952) | async def classify(request: ModerationsRequest):
  function rerank (line 995) | async def rerank(request: RerankRequest):
  function create_embeddings (line 1038) | async def create_embeddings(request: CustomEmbeddingsRequest, model_name...
  function get_classify (line 1085) | async def get_classify(payload: Dict[str, Any]):
  function get_embedding (line 1094) | async def get_embedding(payload: Dict[str, Any]):
  function count_tokens (line 1107) | async def count_tokens(request: APITokenCheckRequest):
  function create_openai_api_server (line 1141) | def create_openai_api_server():
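
`chat_completion_stream_generator` and `generate_completion_stream_generator` serve the OpenAI streaming format. That format is a sequence of Server-Sent Events, each consisting of a `data: <json>` line followed by a blank line, and it ends with `data: [DONE]`. Below is a minimal sketch of that framing. It assumes plain dicts instead of the module's pydantic response models. `make_delta_chunk` and `sse_stream` are hypothetical helpers, and each chunk carries only a subset of the OpenAI fields.

```python
import asyncio
import json
from typing import Any, AsyncIterator, Dict, Optional


def make_delta_chunk(
    chunk_id: str, model: str, content: str, finish_reason: Optional[str] = None
) -> Dict[str, Any]:
    """Build a minimal chat.completion.chunk payload (subset of the OpenAI fields)."""
    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "model": model,
        "choices": [
            {"index": 0, "delta": {"content": content}, "finish_reason": finish_reason}
        ],
    }


async def sse_stream(chunks: AsyncIterator[Dict[str, Any]]) -> AsyncIterator[str]:
    """Frame JSON chunks as Server-Sent Events and terminate with [DONE]."""
    async for chunk in chunks:
        # Each event is a "data:" line followed by a blank line.
        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
    # OpenAI clients stop reading the stream at this sentinel.
    yield "data: [DONE]\n\n"


async def _demo() -> None:
    async def fake_model():
        # Stand-in for the token stream a model worker would return.
        for piece in ["Hel", "lo"]:
            yield make_delta_chunk("chatcmpl-demo", "demo-model", piece)
        yield make_delta_chunk("chatcmpl-demo", "demo-model", "", "stop")

    async for event in sse_stream(fake_model()):
        print(event, end="")


if __name__ == "__main__":
    asyncio.run(_demo())
```

With FastAPI, an async generator like this can be returned as `StreamingResponse(sse_stream(chunks), media_type="text/event-stream")`.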

FILE: gpt_server/serving/server_ui.py
  function get_process_num (line 26) | def get_process_num():
  function update_config (line 37) | def update_config(config: dict):
  function serve_args (line 58) | def serve_args():
  function controller_args (line 82) | def controller_args():
  function model_worker_args (line 101) | def model_worker_args():

FILE: gpt_server/settings.py
  class ModelConfig (line 4) | class ModelConfig(BaseSettings):
  function get_model_config (line 27) | def get_model_config() -> ModelConfig:

FILE: gpt_server/utils.py
  function _register (line 29) | def _register(group: str, proc: subprocess.Popen):
  function _kill_tree (line 33) | def _kill_tree(pid: int, timeout: int = 5):
  function _graceful_shutdown (line 57) | def _graceful_shutdown():
  function clear_flashinfer_cache (line 66) | def clear_flashinfer_cache():
  function delete_flash_attn (line 70) | def delete_flash_attn():
  function pre_processing (line 95) | def pre_processing():
  function signal_handler (line 108) | def signal_handler(signum, frame):
  function run_cmd (line 117) | def run_cmd(cmd: str, group: str = "worker") -> subprocess.Popen:
  function start_controller (line 126) | def start_controller(controller_host, controller_port, dispatch_method):
  function start_openai_server (line 136) | def start_openai_server(host, port, controller_address, api_keys=None):
  function start_api_server (line 148) | def start_api_server(config: dict):
  function get_model_types (line 179) | def get_model_types():
  function start_model_worker (line 197) | def start_model_worker(config: dict):
  function start_server (line 370) | def start_server(
  function delete_log (line 398) | def delete_log():
  function get_free_tcp_port (line 412) | def get_free_tcp_port():
  function is_port_in_use (line 421) | def is_port_in_use(port: int):
  function get_physical_ip (line 430) | def get_physical_ip():
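
`get_free_tcp_port` and `is_port_in_use` appear in the index without their bodies. The sketch below uses a common standard-library approach, which is an assumption rather than the project's code. It binds to port 0 so the OS picks an unused port, and it probes a port with `connect_ex`.

```python
import socket


def get_free_tcp_port() -> int:
    # Binding to port 0 lets the OS choose an unused ephemeral port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    # connect_ex returns 0 when a listener accepts the connection.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1.0)
        return s.connect_ex((host, port)) == 0
```

A port returned by `get_free_tcp_port` is only guaranteed free at the moment of the check. Another process can take it before the worker binds it.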

FILE: gpt_server/version.py
  function parse_version_info (line 7) | def parse_version_info(version_str: str) -> Tuple:
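
`parse_version_info` turns a version string such as `__version__` into a tuple that compares numerically. Its body is not shown. This sketch assumes an OpenMMLab-style helper, a guess suggested only by the OpenMMLab-derived code the repo already carries in `reasoning_parser.py`. Under that scheme, `"1.2.0rc1"` becomes `(1, 2, 0, "rc1")`.

```python
from typing import List, Tuple, Union


def parse_version_info(version_str: str) -> Tuple[Union[int, str], ...]:
    """Turn "0.6.0" into (0, 6, 0) and "1.2.0rc1" into (1, 2, 0, "rc1")."""
    info: List[Union[int, str]] = []
    for part in version_str.split("."):
        if part.isdigit():
            info.append(int(part))
        elif "rc" in part:
            # Split a release-candidate suffix off the numeric part.
            base, rc = part.split("rc", 1)
            info.append(int(base))
            info.append(f"rc{rc}")
    return tuple(info)
```

Because the result is a tuple of integers, ordinary tuple comparison orders releases correctly. For example, `parse_version_info("0.7.2") > parse_version_info("0.6.0")` is true.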

FILE: setup.py
  function readme (line 9) | def readme():
  function get_version (line 15) | def get_version():

FILE: tests/download_model.py
  function model_download (line 10) | def model_download(model_id, local_dir="/data", hub_name="hf", repo_type...

FILE: tests/responses_api/test_openai_responses_response_format.py
  class Distance (line 28) | class Distance(BaseModel):

FILE: tests/responses_api/test_openai_responses_tool_calling.py
  function get_weather (line 6) | def get_weather(location: str, unit: str = "2") -> str:
  function main (line 35) | def main():

FILE: tests/sglang/models.py
  class CustomOpenAIServingChat (line 25) | class CustomOpenAIServingChat(OpenAIServingChat):
    method _process_messages (line 26) | def _process_messages(self, request, is_multimodal):
  function main (line 33) | async def main():

FILE: tests/test_embedding_dynamic_batch.py
  function f (line 6) | async def f():
  function main (line 17) | async def main():

FILE: tests/test_openai_completion_response_format.py
  class Distance (line 28) | class Distance(BaseModel):

FILE: tests/test_openai_completion_tool_calling.py
  function get_weather (line 8) | def get_weather(location: str, unit: str = "celsius"):
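
`get_weather` in the tool-calling test is the function exposed to the model. In the OpenAI Chat Completions API, such a function is described to the model as a JSON Schema entry in `tools`, and the model's `tool_calls` return the arguments as a JSON string. The sketch below covers both sides. The schema's description text and the placeholder return value are assumptions, not the test's actual contents.

```python
import json
from typing import Any, Callable, Dict


def get_weather(location: str, unit: str = "celsius") -> str:
    # Placeholder body: a real tool would query a weather service here.
    return json.dumps({"location": location, "unit": unit, "forecast": "unknown"})


# One entry for the `tools` list of a Chat Completions request.
WEATHER_TOOL: Dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}

# Registry used to route the model's tool calls to Python functions.
TOOLS: Dict[str, Callable[..., str]] = {"get_weather": get_weather}


def dispatch_tool_call(name: str, arguments: str) -> str:
    """Execute one tool call; `arguments` is the JSON string the model produced."""
    return TOOLS[name](**json.loads(arguments))
```

Pass `tools=[WEATHER_TOOL]` to `client.chat.completions.create(...)`. Then, for each entry in `response.choices[0].message.tool_calls`, call `dispatch_tool_call(call.function.name, call.function.arguments)` and send the result back as a `{"role": "tool", "tool_call_id": call.id, "content": ...}` message.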

FILE: tests/test_openai_embedding_vl.py
  function image_to_base64 (line 14) | def image_to_base64(image_path):

FILE: tests/test_openai_vl_chat.py
  function image_to_base64 (line 6) | def image_to_base64(image_path):
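
`image_to_base64` converts an image file to a Base64 string so that a local image can be sent to a vision model. The sketch below assumes the OpenAI `image_url` content-part format with an inline `data:` URL. `image_message` is a hypothetical helper. The MIME type falls back to `image/png` when it cannot be guessed from the file extension.

```python
import base64
import mimetypes
import tempfile
from pathlib import Path
from typing import Any, Dict


def image_to_base64(image_path: str) -> str:
    """Read an image file and return its bytes as a Base64 string."""
    return base64.b64encode(Path(image_path).read_bytes()).decode("ascii")


def image_message(image_path: str, prompt: str) -> Dict[str, Any]:
    """Build a user message with a text part and an inline data-URL image part."""
    mime = mimetypes.guess_type(image_path)[0] or "image/png"
    data_url = f"data:{mime};base64,{image_to_base64(image_path)}"
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": data_url}},
        ],
    }


if __name__ == "__main__":
    # Demo with a throwaway file; a real call would point at an actual image.
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "demo.png"
        path.write_bytes(b"not really a png")
        print(image_message(str(path), "Describe this image.")["content"][1])
```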

FILE: tests/test_rerank.py
  function rerank (line 7) | def rerank():
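
`tests/test_rerank.py` posts to `http://localhost:8082/v1/rerank`, an endpoint intended for tools such as Dify. The request schema is not shown. This sketch assumes Jina/Cohere-style fields (`query`, `documents`, and an optional `top_n`) and a response whose `results` entries carry `index` and `relevance_score`. Both helper names are hypothetical.

```python
from typing import Any, Dict, List, Optional, Tuple


def build_rerank_payload(
    model: str, query: str, documents: List[str], top_n: Optional[int] = None
) -> Dict[str, Any]:
    """Assemble a rerank request body (assumed Jina/Cohere-style field names)."""
    payload: Dict[str, Any] = {"model": model, "query": query, "documents": documents}
    if top_n is not None:
        payload["top_n"] = top_n  # ask the server for only the best N hits
    return payload


def ranked_documents(
    response: Dict[str, Any], documents: List[str]
) -> List[Tuple[str, float]]:
    """Map each result's index back to its document, best score first."""
    results = sorted(
        response["results"], key=lambda r: r["relevance_score"], reverse=True
    )
    return [(documents[r["index"]], r["relevance_score"]) for r in results]
```

Under these assumptions, the JSON returned by `requests.post("http://localhost:8082/v1/rerank", json=payload).json()` would be passed to `ranked_documents` together with the original document list.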

FILE: tests/vllm/embedding.py
  function main (line 19) | async def main():

FILE: tests/vllm/models.py
  class CustomOpenAIServingChat (line 18) | class CustomOpenAIServingChat(OpenAIServingChat):
    method render_chat_request (line 19) | async def render_chat_request(self, request):
  function main (line 51) | async def main():

Condensed preview: 93 files, each showing path, character count, and a content snippet (the full structured content is about 390K chars).
[
  {
    "path": ".dockerignore",
    "chars": 78,
    "preview": "test/\n.vscode/\n.venv/\n__pycache__/\n*.log*\n*.egg-info\nlogs/\noutputs/\ndata/\n.env"
  },
  {
    "path": ".github/workflows/docker-image.yml",
    "chars": 3642,
    "preview": "name: Docker Image CI\n\non:\n  # release:\n  #   types: \n  #     - published  # triggered when a new release is published\n  push:\n    branches:\n   "
  },
  {
    "path": ".gitignore",
    "chars": 102,
    "preview": ".vscode/\n__pycache__/\n*.log*\n*.egg-info\ntest/\nlogs/\noutputs/\ndata/\n.venv/\nconfig.yaml\n.env\n*_test.yaml"
  },
  {
    "path": ".python-version",
    "chars": 5,
    "preview": "3.11\n"
  },
  {
    "path": "Dockerfile",
    "chars": 674,
    "preview": "# FROM docker.1ms.run/506610466/cuda:12.2.2-runtime-ubuntu20.04-uv\nFROM 506610466/cuda:12.2.2-devel-ubuntu20.04-uv\n# From the base"
  },
  {
    "path": "Dockerfile.copy",
    "chars": 109,
    "preview": "FROM docker.1ms.run/506610466/gpt_server:latest \n\nCOPY ./ /gpt_server\n\nWORKDIR /gpt_server\n\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "MANIFEST.in",
    "chars": 32,
    "preview": "include gpt_server/script/*.yaml"
  },
  {
    "path": "README.md",
    "chars": 16252,
    "preview": "<div align=\"center\">\n\n<a href=\"https://github.com/shell-nlp/gpt_server\"><img src=\"assets/logo.png\" width=\"252\" height=\"1"
  },
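  {
    "path": "docker-compose-bash.yaml",
    "chars": 998,
    "preview": "# This container is meant for users who want to work with the project directly inside the container\nversion: '3.8'\nservices:\n  gpt_server_bash:\n    # ------ Build the latest code image from the project ------\n    # build:\n    "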
  },
  {
    "path": "docker-compose.yml",
    "chars": 986,
    "preview": "version: '3'\nservices:\n  gpt_server:\n    # Build\n    # Why build every time (rather than directly using image: docker.1ms.run/506610466/gpt_server:latest)?\n"
  },
  {
    "path": "gpt_server/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/cli.py",
    "chars": 695,
    "preview": "import subprocess\nimport os\nimport typer\n\napp = typer.Typer()\nroot_dir = os.path.dirname(__file__)\nroot_dir = os.path.ab"
  },
  {
    "path": "gpt_server/database/models/process_manager.py",
    "chars": 4072,
    "preview": "\"\"\"This code is not currently in use\"\"\"\n\nfrom typing import List, Dict, Optional, Any\nfrom multiprocessing import Process\nfrom sqlmodel import S"
  },
  {
    "path": "gpt_server/model_backend/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/model_backend/base.py",
    "chars": 214,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict\n\n\nclass ModelBackend(ABC):\n    @abstractmethod\n    def "
  },
  {
    "path": "gpt_server/model_backend/hf_backend.py",
    "chars": 7500,
    "preview": "from typing import Any, Dict\nimport torch\nimport json\nfrom peft import PeftModel\nfrom transformers import TextIteratorSt"
  },
  {
    "path": "gpt_server/model_backend/lmdeploy_backend.py",
    "chars": 8668,
    "preview": "import os\nimport sys\nfrom lmdeploy import (\n    GenerationConfig,\n    TurbomindEngineConfig,\n    PytorchEngineConfig,\n)\n"
  },
  {
    "path": "gpt_server/model_backend/sglang_backend.py",
    "chars": 10003,
    "preview": "import asyncio\nimport json\nfrom typing import Any, AsyncGenerator, Dict\n\nfrom loguru import logger\nfrom sglang.srt.entry"
  },
  {
    "path": "gpt_server/model_backend/utils.py",
    "chars": 2430,
    "preview": "from typing import List, Type, Union\nfrom pydantic import BaseModel\nfrom transformers.generation.logits_process import L"
  },
  {
    "path": "gpt_server/model_backend/vllm_backend.py",
    "chars": 12247,
    "preview": "from dataclasses import asdict\nimport json\nfrom typing import Any, AsyncGenerator, Dict\n\nfrom loguru import logger\nfrom "
  },
  {
    "path": "gpt_server/model_handler/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/model_handler/chat_template/get_chat_template.py",
    "chars": 709,
    "preview": "from pathlib import Path\nfrom typing import Literal\n\ncur_path = Path(__file__).parent\n\n\ndef get_chat_template(model_name"
  },
  {
    "path": "gpt_server/model_handler/chat_template/qwen3.jinja",
    "chars": 4235,
    "preview": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].con"
  },
  {
    "path": "gpt_server/model_handler/chat_template/qwen3_zh.jinja",
    "chars": 4141,
    "preview": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].con"
  },
  {
    "path": "gpt_server/model_handler/chat_template/qwen3vl.jinja",
    "chars": 5326,
    "preview": "{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- macro render_content(content"
  },
  {
    "path": "gpt_server/model_handler/pitch.py",
    "chars": 1286,
    "preview": "from typing import Optional\nfrom flashtts.llm.vllm_generator import VllmGenerator\nimport flashtts\nfrom loguru import log"
  },
  {
    "path": "gpt_server/model_handler/reasoning_parser.py",
    "chars": 7507,
    "preview": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/en"
  },
  {
    "path": "gpt_server/model_handler/tool_parser.py",
    "chars": 9750,
    "preview": "import json\nimport re\nfrom typing import List, Literal, Optional\n\nfrom loguru import logger\nfrom pydantic import BaseMod"
  },
  {
    "path": "gpt_server/model_handler/utils.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/model_worker/__init__.py",
    "chars": 2277,
    "preview": "from gpt_server.model_worker.utils import patch\nimport os\n\nos.environ[\"VLLM_WORKER_MULTIPROC_METHOD\"] = \"spawn\"\npatch()\n"
  },
  {
    "path": "gpt_server/model_worker/auto.py",
    "chars": 3302,
    "preview": "import json\nimport traceback\nfrom typing import List\n\nfrom fastchat.constants import ErrorCode, SERVER_ERROR_MSG\nfrom lo"
  },
  {
    "path": "gpt_server/model_worker/base/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/model_worker/base/base_model_worker.py",
    "chars": 5270,
    "preview": "import threading\nimport time\nfrom typing import List\n\nfrom fastapi import FastAPI, Request, BackgroundTasks\nfrom fastapi"
  },
  {
    "path": "gpt_server/model_worker/base/model_worker_base.py",
    "chars": 20326,
    "preview": "import asyncio\nfrom datetime import datetime\nfrom typing import List\nimport json\nimport sys\nimport shutil\nfrom abc impor"
  },
  {
    "path": "gpt_server/model_worker/embedding_infinity.py",
    "chars": 5254,
    "preview": "import os\nfrom typing import List\nimport asyncio\nfrom loguru import logger\n\nfrom infinity_emb import AsyncEngineArray, E"
  },
  {
    "path": "gpt_server/model_worker/embedding_sentence_transformers.py",
    "chars": 1340,
    "preview": "import os\nfrom typing import List\n\nfrom loguru import logger\nfrom gpt_server.model_worker.base.model_worker_base import "
  },
  {
    "path": "gpt_server/model_worker/embedding_v2.py",
    "chars": 5869,
    "preview": "import os\nfrom typing import List\nfrom gpt_server.model_worker.base.model_worker_base import ModelWorkerBase\nimport sent"
  },
  {
    "path": "gpt_server/model_worker/embedding_vllm.py",
    "chars": 4311,
    "preview": "import os\nfrom typing import List\nfrom loguru import logger\n\nfrom gpt_server.model_worker.base.model_worker_base import "
  },
  {
    "path": "gpt_server/model_worker/flux.py",
    "chars": 3083,
    "preview": "import asyncio\n\nimport io\nimport os\nfrom typing import List\nimport uuid\nfrom loguru import logger\nimport shortuuid\nfrom "
  },
  {
    "path": "gpt_server/model_worker/funasr.py",
    "chars": 2015,
    "preview": "import os\nfrom typing import List\nimport base64\nfrom loguru import logger\nfrom gpt_server.model_worker.base.model_worker"
  },
  {
    "path": "gpt_server/model_worker/qwen_image.py",
    "chars": 3895,
    "preview": "import asyncio\nimport os\nfrom typing import List\nimport uuid\nfrom loguru import logger\nimport shortuuid\nfrom gpt_server."
  },
  {
    "path": "gpt_server/model_worker/qwen_image_edit.py",
    "chars": 3546,
    "preview": "import asyncio\n\nimport os\nfrom typing import List\nimport uuid\nfrom loguru import logger\nimport shortuuid\nfrom gpt_server"
  },
  {
    "path": "gpt_server/model_worker/spark_tts.py",
    "chars": 3713,
    "preview": "import asyncio\n\nimport os\nfrom typing import List\nfrom loguru import logger\nfrom gpt_server.model_handler.pitch import p"
  },
  {
    "path": "gpt_server/model_worker/utils.py",
    "chars": 8989,
    "preview": "import httpx\nfrom loguru import logger\nfrom fastapi import HTTPException\nimport base64\nimport io\nimport os\nfrom PIL impo"
  },
  {
    "path": "gpt_server/model_worker/voxcpm_tts.py",
    "chars": 3101,
    "preview": "import os\nfrom typing import List\nfrom loguru import logger\nimport numpy as np\nfrom gpt_server.model_handler.pitch impor"
  },
  {
    "path": "gpt_server/model_worker/wan.py",
    "chars": 2907,
    "preview": "import asyncio\n\nimport io\nimport os\nfrom typing import List\nimport uuid\nfrom loguru import logger\nimport shortuuid\nfrom "
  },
  {
    "path": "gpt_server/model_worker/z_image.py",
    "chars": 3737,
    "preview": "import asyncio\nimport os\nfrom typing import List\nimport uuid\nfrom loguru import logger\nimport shortuuid\nfrom gpt_server."
  },
  {
    "path": "gpt_server/openai_api_protocol/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/openai_api_protocol/custom_api_protocol.py",
    "chars": 16663,
    "preview": "import time\nfrom typing import Any, Dict, List, Literal, Optional, TypeAlias, Union\nimport uuid\n\nfrom pydantic import Fi"
  },
  {
    "path": "gpt_server/script/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/script/config_example.yaml",
    "chars": 6636,
    "preview": "# Start in the background: nohup sh start.sh > gptserver.log &\n# openai_api_server\nserve_args:\n  # host and port of the openai service\n  enable: true\n  h"
  },
  {
    "path": "gpt_server/script/start.sh",
    "chars": 131,
    "preview": "#!/usr/bin/env bash\n\nscript_dir=$(cd $(dirname $0);pwd)\n\necho $(dirname $script_dir)\n\npython $(dirname $script_dir)/serv"
  },
  {
    "path": "gpt_server/script/stop.sh",
    "chars": 161,
    "preview": "#!/usr/bin/env bash\n\n# ps -ef | grep fastchat.serve | awk '{print $2}' |xargs -I{} kill -9 {}\n\nps -ef | grep gpt_server "
  },
  {
    "path": "gpt_server/serving/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "gpt_server/serving/chat_ui.py",
    "chars": 3126,
    "preview": "import streamlit as st\nfrom openai import OpenAI\nimport os\nimport sys\nimport yaml\n\nif \"config\" not in st.session_state:\n"
  },
  {
    "path": "gpt_server/serving/controller.py",
    "chars": 12347,
    "preview": "\"\"\"\nA controller manages distributed workers.\nIt sends worker addresses to clients.\n\"\"\"\n\nimport argparse\nimport asyncio\n"
  },
  {
    "path": "gpt_server/serving/controller_v2.py",
    "chars": 13411,
    "preview": "\"\"\"\nA controller manages distributed workers.\nIt sends worker addresses to clients.\nThis version is modified to use SQLM"
  },
  {
    "path": "gpt_server/serving/main.py",
    "chars": 2058,
    "preview": "import time\nimport yaml\nimport os\nimport sys\nimport ray\nfrom dotenv import load_dotenv\nfrom loguru import logger\nimport "
  },
  {
    "path": "gpt_server/serving/openai_api_server.py",
    "chars": 40552,
    "preview": "\"\"\"A server that provides OpenAI-compatible RESTful APIs. It supports:\n\n- Chat Completions. (Reference: https://platform"
  },
  {
    "path": "gpt_server/serving/server_ui.py",
    "chars": 13740,
    "preview": "import streamlit as st\nimport yaml\nimport os\nimport sys\nfrom loguru import logger\nfrom copy import deepcopy\nimport subpr"
  },
  {
    "path": "gpt_server/settings.py",
    "chars": 807,
    "preview": "from pydantic_settings import BaseSettings\n\n\nclass ModelConfig(BaseSettings):\n    model_name_or_path: str | None = None\n"
  },
  {
    "path": "gpt_server/utils.py",
    "chars": 16605,
    "preview": "import socket\nfrom typing import List, Optional\nimport os\nimport sys\nimport json\nimport subprocess\nfrom loguru import lo"
  },
  {
    "path": "gpt_server/version.py",
    "chars": 724,
    "preview": "from typing import Tuple\n\n__version__ = \"0.6.0\"\nshort_version = __version__\n\n\ndef parse_version_info(version_str: str) -"
  },
  {
    "path": "pyproject.toml",
    "chars": 2076,
    "preview": "[project]\nname = \"gpt_server\"\nversion = \"0.7.2\"\ndescription = \"gpt_server is an open-source framework for production-grade deployment of LLMs, Embedding, Reranker, ASR and TTS."
  },
  {
    "path": "setup.py",
    "chars": 826,
    "preview": "import os\nfrom setuptools import setup, find_packages\n\n\npwd = os.path.dirname(__file__)\nversion_file = \"gpt_server/versi"
  },
  {
    "path": "tests/download_model.py",
    "chars": 1237,
    "preview": "\"\"\"\nTo download with hf:\npip install -U huggingface_hub hf_transfer\n\nTo download with modelscope:\npip install modelscope\n\"\"\"\n\n\ndef m"
  },
  {
    "path": "tests/responses_api/test_openai_responses.py",
    "chars": 916,
    "preview": "from openai import OpenAI\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")\n\nstream = T"
  },
  {
    "path": "tests/responses_api/test_openai_responses_response_format.py",
    "chars": 958,
    "preview": "from openai import OpenAI\nfrom pydantic import BaseModel, Field\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\""
  },
  {
    "path": "tests/responses_api/test_openai_responses_tool_calling.py",
    "chars": 1648,
    "preview": "import json\n\nfrom openai import OpenAI\n\n\ndef get_weather(location: str, unit: str = \"2\") -> str:\n    \"\"\"\n    Get the cur"
  },
  {
    "path": "tests/responses_api/test_response_vl_chat.py",
    "chars": 590,
    "preview": "from openai import OpenAI\n\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")\n\nstream = True\nresponse"
  },
  {
    "path": "tests/sglang/models.py",
    "chars": 2453,
    "preview": "import asyncio\n\nimport os\nfrom sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat\nfrom sglang.srt.entry"
  },
  {
    "path": "tests/test_chat_template.py",
    "chars": 1796,
    "preview": "from transformers import AutoTokenizer\n\nurl = \"https://opencompass.oss-cn-shanghai.aliyuncs.com/image/compass-hub/botcha"
  },
  {
    "path": "tests/test_embedding_dynamic_batch.py",
    "chars": 1068,
    "preview": "import asyncio\nfrom openai import AsyncOpenAI\nimport time\n\n\nasync def f():\n    batch = 5\n    client = AsyncOpenAI(api_ke"
  },
  {
    "path": "tests/test_image_edit.py",
    "chars": 704,
    "preview": "import base64\nfrom pathlib import Path\nfrom openai import OpenAI\n\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://loca"
  },
  {
    "path": "tests/test_image_gen.py",
    "chars": 726,
    "preview": "import base64\nfrom openai import OpenAI\n\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")\n# two response formats\n"
  },
  {
    "path": "tests/test_mteb.py",
    "chars": 1253,
    "preview": "\"\"\"MTEB tasks for evaluating Embedding models\nMetric documentation: https://evalscope.readthedocs.io/zh-cn/latest/user_guides/backend/rageval_backend"
  },
  {
    "path": "tests/test_needle_haystack.py",
    "chars": 1709,
    "preview": "\"\"\"Needle-in-a-haystack evaluation\"\"\"\n\nimport os\nfrom evalscope import TaskConfig, run_task\n\ntask_cfg = TaskConfig(\n    model=\"qwen\",\n    api_url"
  },
  {
    "path": "tests/test_openai_chat.py",
    "chars": 540,
    "preview": "from openai import OpenAI\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")\n\nstream = T"
  },
  {
    "path": "tests/test_openai_completion.py",
    "chars": 417,
    "preview": "from openai import OpenAI\nimport time\nfrom rich import print\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"htt"
  },
  {
    "path": "tests/test_openai_completion_response_format.py",
    "chars": 915,
    "preview": "from openai import OpenAI\nfrom pydantic import BaseModel, Field\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\""
  },
  {
    "path": "tests/test_openai_completion_tool_calling.py",
    "chars": 1394,
    "preview": "from openai import OpenAI\nimport json\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")"
  },
  {
    "path": "tests/test_openai_embedding.py",
    "chars": 570,
    "preview": "from openai import OpenAI\nfrom rich import print\nimport numpy as np\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_u"
  },
  {
    "path": "tests/test_openai_embedding_vl.py",
    "chars": 811,
    "preview": "from openai import OpenAI\nfrom rich import print\nimport base64\n\n\n## Test: text-only embedding\nclient = OpenAI(api_key=\"EMPTY\", base_url="
  },
  {
    "path": "tests/test_openai_moderation.py",
    "chars": 331,
    "preview": "from openai import OpenAI\nfrom rich import print\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhos"
  },
  {
    "path": "tests/test_openai_rerank.py",
    "chars": 277,
    "preview": "from openai import OpenAI\nfrom rich import print\n\n# new version of openai\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhos"
  },
  {
    "path": "tests/test_openai_transcriptions.py",
    "chars": 280,
    "preview": "from openai import OpenAI\n\nclient = OpenAI(api_key=\"EMPTY\", base_url=\"http://localhost:8082/v1\")\n\naudio_file = open(\"/ho"
  },
  {
    "path": "tests/test_openai_tts_stream.py",
    "chars": 1370,
    "preview": "import base64\nfrom pathlib import Path\nfrom openai import OpenAI\n\nspeech_file_path = Path(__file__).parent / \"speech.mp3"
  },
  {
    "path": "tests/test_openai_vl_chat.py",
    "chars": 1399,
    "preview": "import base64\nfrom openai import OpenAI\nfrom pathlib import Path\n\n\ndef image_to_base64(image_path):\n    \"\"\"Convert the image to a Base64 string"
  },
  {
    "path": "tests/test_perf.py",
    "chars": 603,
    "preview": "from evalscope.perf.arguments import Arguments\nfrom evalscope.perf.main import run_perf_benchmark\nfrom rich import print"
  },
  {
    "path": "tests/test_rerank.py",
    "chars": 663,
    "preview": "\"\"\"Supports open-source projects such as dify\"\"\"\n\nimport requests\nfrom rich import print\n\n\ndef rerank():\n    url = f\"http://localhost:8082/v1/rerank\""
  },
  {
    "path": "tests/vllm/embedding.py",
    "chars": 1937,
    "preview": "import asyncio\nfrom vllm.engine.arg_utils import AsyncEngineArgs\nfrom vllm.engine.async_llm_engine import AsyncLLMEngine"
  },
  {
    "path": "tests/vllm/models.py",
    "chars": 3043,
    "preview": "import asyncio\nfrom vllm.engine.arg_utils import AsyncEngineArgs\nfrom vllm.engine.async_llm_engine import AsyncLLMEngine"
  }
]

About this extraction

This page contains the full source code of the shell-nlp/gpt_server GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 93 files (347.1 KB), approximately 85.6k tokens, and a symbol index with 370 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.
