[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yaml",
    "content": "name: \"\\U0001F41B Bug Report\"\ndescription: Submit a bug report to help us improve ChatGLM3 / 提交一个 Bug 问题报告来帮助我们改进 ChatGLM3\nbody:\n  - type: textarea\n    id: system-info\n    attributes:\n      label: System Info / 系統信息\n      description: Your operating environment / 您的运行环境信息\n      placeholder: Include your CUDA version, Transformers version, Python version, operating system, and hardware information (if you suspect a hardware problem)... / 包括Cuda版本，Transformers版本，Python版本，操作系统，硬件信息(如果您怀疑是硬件方面的问题)...\n    validations:\n      required: true\n\n  - type: textarea\n    id: who-can-help\n    attributes:\n      label: Who can help? / 谁可以帮助到您？\n      description: |\n        Your issue will be replied to more quickly if you can figure out the right person to tag with @\n        All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.\n\n        Please tag fewer than 3 people.\n\n        如果您能找到合适的标签 @，您的问题会更快得到回复。\n        所有问题都会由我们的维护者阅读，如果您不知道该标记谁，只需留空，我们的维护人员会找到合适的开发组成员来解决问题。\n\n        标记的人数应该不超过 3 个人。\n\n        Related demo leader / 相关demo负责人:\n        - finetune_demo: @Btlmd\n        - langchain_demo: @yincf\n        - composite_demo: @abmfy\n\n        If the bug is not in one of these three subsections, you may leave the helper unspecified. 
Our maintainer will find the right person in the development group to solve the problem.\n        \n        如果不是这三个子版块的bug，您可以不指明帮助者，我们的维护人员会找到合适的开发组成员来解决问题。\n\n      placeholder: \"@Username ...\"\n\n  - type: checkboxes\n    id: information-scripts-examples\n    attributes:\n      label: Information / 问题信息\n      description: 'The problem arises when using: / 问题出现在'\n      options:\n        - label: \"The official example scripts / 官方的示例脚本\"\n        - label: \"My own modified scripts / 我自己修改的脚本和任务\"\n\n  - type: textarea\n    id: reproduction\n    validations:\n      required: true\n    attributes:\n      label: Reproduction / 复现过程\n      description: |\n        Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.\n        If you have code snippets, error messages, stack traces, please provide them here as well.\n        Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting\n        Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.\n        \n        请提供能重现您遇到的问题的代码示例,最好是最小复现单元。\n        如果您有代码片段、错误信息、堆栈跟踪，也请在此提供。\n        请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting\n        请勿使用截图，因为截图难以阅读，而且（更重要的是）不允许他人复制粘贴您的代码。\n      placeholder: |\n        Steps to reproduce the behavior/复现Bug的步骤:\n          \n          1.\n          2.\n          3.\n\n  - type: textarea\n    id: expected-behavior\n    validations:\n      required: true\n    attributes:\n      label: Expected behavior / 期待表现\n      description: \"A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。\""
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.yaml",
    "content": "name: \"\\U0001F680 Feature request\"\ndescription: Submit a request for a new ChatGLM3 feature / 提交一个新的 ChatGLM3 的功能建议\nlabels: [ \"feature\" ]\nbody:\n  - type: textarea\n    id: feature-request\n    validations:\n      required: true\n    attributes:\n      label: Feature request / 功能建议\n      description: |\n        A brief description of the proposed feature. Links to relevant papers and code are welcome.\n        对功能建议的简述。最好提供对应的论文和代码链接\n\n  - type: textarea\n    id: motivation\n    validations:\n      required: true\n    attributes:\n      label: Motivation / 动机\n      description: |\n        Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.\n        您提出建议的动机。如果该动机与另一个 GitHub 问题有关，请在此处提供对应的链接。\n\n  - type: textarea\n    id: contribution\n    validations:\n      required: true\n    attributes:\n      label: Your contribution / 您的贡献\n      description: |\n        A link to your PR, or any other link where you can help.\n        您的PR链接或者其他您能提供帮助的链接。"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/pr_template.md",
    "content": "# Raise a valuable PR / 提出有价值的PR\n\n## Caution / 注意事项:\nUsers should keep the following points in mind when submitting PRs:\n\n1. The proposed PR should be about this project.\n2. The proposed PR should be focused; if there are multiple ideas and optimizations, they should be split into separate PRs.\n\n用户在提交PR时候应该注意以下几点:\n\n1. 提出的PR应该是关于本项目的。\n2. 提出的PR应该具有针对性，如果具有多个不同的想法和优化方案，应该分配到不同的PR中。\n\n## PRs that should not be proposed / 不应该提出的PR\n\nIf a developer proposes a PR about any of the following, it may be closed or rejected.\n\n1. PRs that do not describe an improvement plan.\n2. PRs that combine multiple issues of different types.\n3. PRs that are highly duplicative of already existing PRs.\n\n如果开发者提出关于以下方面的PR，则可能会被直接关闭或拒绝通过。\n\n1. 没有说明改进方案的。\n2. 多个不同类型的问题合并在一个PR中的。\n3. 提出的PR与已经存在的PR高度重复的。\n\n# Check your PR / 检查您的PR\n- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分？\n- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 GitHub 问题或论坛讨论/批准过？如果是，请添加链接。\n- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档？这里是文档指南，这里是文档格式化技巧。\n- [ ] Did you write new required tests? / 您是否编写了新的必要测试？\n- [ ] Is your PR for only one issue? / 您的PR是否仅针对一个问题？"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\n\n# finetune_demo: generated & downloaded files\nfinetune_demo/output\nfinetune_demo/data\nfinetune_demo/formatted_data\nToolAlpaca/\nAdvertiseGen/\n*.gz\n*.idea\n.DS_Store"
  },
  {
    "path": "DEPLOYMENT.md",
    "content": "## 低成本部署\n\n### 模型量化\n\n默认情况下，模型以 FP16 精度加载，运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限，可以尝试以量化方式加载模型，使用方法如下：\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).quantize(4).cuda()\n```\n\n模型量化会带来一定的性能损失，经过测试，ChatGLM3-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。\n\n### CPU 部署\n\n如果你没有 GPU 硬件的话，也可以在 CPU 上进行推理，但是推理速度会更慢。使用方法如下（需要大概 32GB 内存）：\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).float()\n```\n\n### Mac 部署\n\n对于搭载了 Apple Silicon 或者 AMD GPU 的 Mac，可以使用 MPS 后端来在 GPU 上运行 ChatGLM3-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly（正确的版本号应该是 2.x.x.dev2023xxxx，而不是 2.x.x）。\n\n目前在 macOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载，并使用 mps 后端：\n```python\nmodel = AutoModel.from_pretrained(\"your local path\", trust_remote_code=True).to('mps')\n```\n\n加载半精度的 ChatGLM3-6B 模型需要大概 13GB 内存。内存较小的机器（比如 16GB 内存的 MacBook Pro），在空余内存不足的情况下会使用硬盘上的虚拟内存，导致推理速度严重变慢。\n\n### 多卡部署\n如果你有多张 GPU，但是每张 GPU 的显存大小都不足以容纳完整的模型，那么可以将模型切分在多张 GPU 上。首先安装 accelerate：`pip install accelerate`，然后通过如下方法加载模型：\n\n```python\nfrom utils import load_model_on_gpus\n\nmodel = load_model_on_gpus(\"THUDM/chatglm3-6b\", num_gpus=2)\n```\n即可将模型部署到两张 GPU 上进行推理。你可以将 `num_gpus` 改为你希望使用的 GPU 数。默认是均匀切分的，你也可以传入 `device_map` 参数来自己指定。"
  },
  {
    "path": "DEPLOYMENT_en.md",
    "content": "## Low-Cost Deployment\n\n### Model Quantization\n\nBy default, the model is loaded with FP16 precision; running the above code requires about 13GB of VRAM. If your GPU's VRAM is limited, you can try loading the model with quantization, as follows:\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).quantize(4).cuda()\n```\n\nModel quantization brings some performance loss. In our testing, ChatGLM3-6B still generates natural and fluent text under 4-bit quantization.\n\n### CPU Deployment\n\nIf you don't have GPU hardware, you can also run inference on the CPU, but the inference speed will be slower. The usage is as follows (requires about 32GB of memory):\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).float()\n```\n\n### Mac Deployment\n\nFor Macs equipped with Apple Silicon or AMD GPUs, the MPS backend can be used to run ChatGLM3-6B on the GPU. Refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the correct version number should be 2.x.x.dev2023xxxx, not 2.x.x).\n\nCurrently, only [loading the model locally](README_en.md#load-model-locally) is supported on macOS. Change the model loading in the code to load locally and use the MPS backend:\n\n```python\nmodel = AutoModel.from_pretrained(\"your local path\", trust_remote_code=True).to('mps')\n```\n\nLoading the half-precision ChatGLM3-6B model requires about 13GB of memory. Machines with less memory (such as a MacBook Pro with 16GB of memory) will use virtual memory on the hard disk when free memory is insufficient, resulting in a significant slowdown in inference speed.\n\n### Multi-GPU Deployment\n\nIf you have multiple GPUs, but no single GPU has enough VRAM to hold the complete model, the model can be split across multiple GPUs. 
First, install accelerate: `pip install accelerate`, and then load the model through the following methods:\n\n```python\nfrom utils import load_model_on_gpus\n\nmodel = load_model_on_gpus(\"THUDM/chatglm3-6b\", num_gpus=2)\n```\n\nThis allows the model to be deployed on two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. It is evenly split by default, but you can also pass the `device_map` parameter to specify it yourself."
  },
  {
    "path": "Intel_device_demo/README.md",
    "content": "# Intel Device Demo\n\nThis folder helps developers deploy and accelerate the ChatGLM3-6B model on Intel devices.\n\n## 1. Hardware Requirements\nSupported devices include:\n- Intel CPUs, including consumer and server / workstation CPUs\n- Intel Arc discrete GPUs, such as the Arc A770\n- Intel integrated GPUs\n- Other Intel toolkits that in principle support OpenVINO acceleration\n\n## 2. Directory Layout\n- IPEX_llm_xxx_demo: IPEX-LLM is a low-precision, lightweight large language model library for Intel XPUs (Xeon/Core/Flex/Arc/PVC), offering broad model support, low latency, and a small memory footprint on Intel platforms; these demos show accelerated model deployment with it.\n- OpenVINO_demo: Accelerated model deployment examples using the Intel OpenVINO inference framework.\n- Pytorch_demo (not yet released): Development in a PyTorch environment using Intel Extension for PyTorch (for Intel Arc GPUs).\n"
  },
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/api_server.py",
    "content": "\"\"\"\nThis script implements an API for the ChatGLM3-6B model,\nformatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).\nIt's designed to be run as a web server using FastAPI and uvicorn,\nmaking the ChatGLM3-6B model accessible through the OpenAI client.\n\nKey Components and Features:\n- Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.\n- FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.\n- API Endpoints:\n  - \"/v1/models\": Lists the available models, specifically ChatGLM3-6B.\n  - \"/v1/chat/completions\": Processes chat completion requests with options for streaming and regular responses.\n  - \"/v1/embeddings\": Processes embedding requests for a list of text inputs.\n- Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.\nFor instance, setting 'max_tokens' to 8192 for a 6B model would result in an error, because the model cannot output\nthat many tokens after accounting for the history and prompt tokens.\n- Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.\n- Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.\n- Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.\n\nNote:\n    This script doesn't include the setup for special tokens or multi-GPU support by default.\n    Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.\n    Embedding models are only supported on a single GPU.\n\n\"\"\"\n\nimport os\nimport time\nimport tiktoken\nimport torch\nimport uvicorn\n\nfrom fastapi import FastAPI, HTTPException, Response\nfrom fastapi.middleware.cors import CORSMiddleware\n\nfrom contextlib import 
asynccontextmanager\nfrom typing import List, Literal, Optional, Union\nfrom loguru import logger\nfrom pydantic import BaseModel, Field\nfrom ipex_llm.transformers import AutoModel\nfrom transformers import AutoTokenizer\nfrom utils import process_response, generate_chatglm3, generate_stream_chatglm3\n# from sentence_transformers import SentenceTransformer\n\nfrom sse_starlette.sse import EventSourceResponse\n\n# Set the SSE ping interval\nEventSourceResponse.DEFAULT_PING_INTERVAL = 1000\n\n# set LLM path\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\n# set Embedding Model path\nEMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-large-zh-v1.5')\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    yield\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.ipc_collect()\n\n\napp = FastAPI(lifespan=lifespan)\n\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],\n    allow_credentials=True,\n    allow_methods=[\"*\"],\n    allow_headers=[\"*\"],\n)\n\n\nclass ModelCard(BaseModel):\n    id: str\n    object: str = \"model\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = \"owner\"\n    root: Optional[str] = None\n    parent: Optional[str] = None\n    permission: Optional[list] = None\n\n\nclass ModelList(BaseModel):\n    object: str = \"list\"\n    data: List[ModelCard] = []\n\n\nclass FunctionCallResponse(BaseModel):\n    name: Optional[str] = None\n    arguments: Optional[str] = None\n\n\nclass ChatMessage(BaseModel):\n    role: Literal[\"user\", \"assistant\", \"system\", \"function\"]\n    content: Optional[str] = None\n    name: Optional[str] = None\n    function_call: Optional[FunctionCallResponse] = None\n\n\nclass DeltaMessage(BaseModel):\n    role: Optional[Literal[\"user\", \"assistant\", \"system\"]] = None\n    content: Optional[str] = None\n    function_call: 
Optional[FunctionCallResponse] = None\n\n\n## for Embedding\nclass EmbeddingRequest(BaseModel):\n    input: List[str]\n    model: str\n\n\nclass CompletionUsage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n\n\nclass EmbeddingResponse(BaseModel):\n    data: list\n    model: str\n    object: str\n    usage: CompletionUsage\n\n\n# for ChatCompletionRequest\n\nclass UsageInfo(BaseModel):\n    prompt_tokens: int = 0\n    total_tokens: int = 0\n    completion_tokens: Optional[int] = 0\n\n\nclass ChatCompletionRequest(BaseModel):\n    model: str\n    messages: List[ChatMessage]\n    temperature: Optional[float] = 0.8\n    top_p: Optional[float] = 0.8\n    max_tokens: Optional[int] = None\n    stream: Optional[bool] = False\n    tools: Optional[Union[dict, List[dict]]] = None\n    repetition_penalty: Optional[float] = 1.1\n\n\nclass ChatCompletionResponseChoice(BaseModel):\n    index: int\n    message: ChatMessage\n    finish_reason: Literal[\"stop\", \"length\", \"function_call\"]\n\n\nclass ChatCompletionResponseStreamChoice(BaseModel):\n    delta: DeltaMessage\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"function_call\"]]\n    index: int\n\n\nclass ChatCompletionResponse(BaseModel):\n    model: str\n    id: str\n    object: Literal[\"chat.completion\", \"chat.completion.chunk\"]\n    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]\n    created: Optional[int] = Field(default_factory=lambda: int(time.time()))\n    usage: Optional[UsageInfo] = None\n\n\n@app.get(\"/health\")\nasync def health() -> Response:\n    \"\"\"Health check.\"\"\"\n    return Response(status_code=200)\n\n\n@app.post(\"/v1/embeddings\", response_model=EmbeddingResponse)\nasync def get_embeddings(request: EmbeddingRequest):\n    embeddings = [embedding_model.encode(text) for text in request.input]\n    embeddings = [embedding.tolist() for embedding in embeddings]\n\n    def num_tokens_from_string(string: 
str) -> int:\n        \"\"\"\n        Returns the number of tokens in a text string.\n        Uses the cl100k_base tokenizer.\n        \"\"\"\n        encoding = tiktoken.get_encoding('cl100k_base')\n        num_tokens = len(encoding.encode(string))\n        return num_tokens\n\n    response = {\n        \"data\": [\n            {\n                \"object\": \"embedding\",\n                \"embedding\": embedding,\n                \"index\": index\n            }\n            for index, embedding in enumerate(embeddings)\n        ],\n        \"model\": request.model,\n        \"object\": \"list\",\n        \"usage\": CompletionUsage(\n            prompt_tokens=sum(len(text.split()) for text in request.input),\n            completion_tokens=0,\n            total_tokens=sum(num_tokens_from_string(text) for text in request.input),\n        )\n    }\n    return response\n\n\n@app.get(\"/v1/models\", response_model=ModelList)\nasync def list_models():\n    model_card = ModelCard(\n        id=\"chatglm3-6b\"\n    )\n    return ModelList(\n        data=[model_card]\n    )\n\n\n@app.post(\"/v1/chat/completions\", response_model=ChatCompletionResponse)\nasync def create_chat_completion(request: ChatCompletionRequest):\n    global model, tokenizer\n\n    if len(request.messages) < 1 or request.messages[-1].role == \"assistant\":\n        raise HTTPException(status_code=400, detail=\"Invalid request\")\n\n    gen_params = dict(\n        messages=request.messages,\n        temperature=request.temperature,\n        top_p=request.top_p,\n        max_tokens=request.max_tokens or 1024,\n        echo=False,\n        stream=request.stream,\n        repetition_penalty=request.repetition_penalty,\n        tools=request.tools,\n    )\n    logger.debug(f\"==== request ====\\n{gen_params}\")\n\n    if request.stream:\n\n        # Use stream mode to read the first few characters; if it is not a function call, stream the output directly\n        predict_stream_generator = 
predict_stream(request.model, gen_params)\n        output = next(predict_stream_generator)\n        if not contains_custom_function(output):\n            return EventSourceResponse(predict_stream_generator, media_type=\"text/event-stream\")\n\n        # Obtain the result directly at one time and determine whether tools need to be called.\n        logger.debug(f\"First result output:\\n{output}\")\n\n        function_call = None\n        if output and request.tools:\n            try:\n                function_call = process_response(output, use_tool=True)\n            except Exception:\n                logger.warning(\"Failed to parse tool call\")\n\n        # Call the function\n        if isinstance(function_call, dict):\n            function_call = FunctionCallResponse(**function_call)\n\n            \"\"\"\n            In this demo, we did not register any tools.\n            You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.\n            Similar to the following method:\n                function_args = json.loads(function_call.arguments)\n                tool_response = dispatch_tool(tool_name: str, tool_params: dict)\n            \"\"\"\n            tool_response = \"\"\n\n            if not gen_params.get(\"messages\"):\n                gen_params[\"messages\"] = []\n\n            gen_params[\"messages\"].append(ChatMessage(\n                role=\"assistant\",\n                content=output,\n            ))\n            gen_params[\"messages\"].append(ChatMessage(\n                role=\"function\",\n                name=function_call.name,\n                content=tool_response,\n            ))\n\n            # Streaming output of results after function calls\n            generate = predict(request.model, gen_params)\n            return EventSourceResponse(generate, media_type=\"text/event-stream\")\n\n        else:\n            # Handled to avoid exceptions in the above parsing function 
process.\n            generate = parse_output_text(request.model, output)\n            return EventSourceResponse(generate, media_type=\"text/event-stream\")\n\n    # Here is the handling of stream = False\n    response = generate_chatglm3(model, tokenizer, gen_params)\n\n    # Remove the first newline character\n    if response[\"text\"].startswith(\"\\n\"):\n        response[\"text\"] = response[\"text\"][1:]\n    response[\"text\"] = response[\"text\"].strip()\n\n    usage = UsageInfo()\n    function_call, finish_reason = None, \"stop\"\n    if request.tools:\n        try:\n            function_call = process_response(response[\"text\"], use_tool=True)\n        except:\n            logger.warning(\"Failed to parse tool call, maybe the response is not a tool call or have been answered.\")\n\n    if isinstance(function_call, dict):\n        finish_reason = \"function_call\"\n        function_call = FunctionCallResponse(**function_call)\n\n    message = ChatMessage(\n        role=\"assistant\",\n        content=response[\"text\"],\n        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,\n    )\n\n    logger.debug(f\"==== message ====\\n{message}\")\n\n    choice_data = ChatCompletionResponseChoice(\n        index=0,\n        message=message,\n        finish_reason=finish_reason,\n    )\n    task_usage = UsageInfo.model_validate(response[\"usage\"])\n    for usage_key, usage_value in task_usage.model_dump().items():\n        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)\n\n    return ChatCompletionResponse(\n        model=request.model,\n        id=\"\",  # for open_source model, id is empty\n        choices=[choice_data],\n        object=\"chat.completion\",\n        usage=usage\n    )\n\n\nasync def predict(model_id: str, params: dict):\n    global model, tokenizer\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(role=\"assistant\"),\n        
finish_reason=None\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    previous_text = \"\"\n    for new_response in generate_stream_chatglm3(model, tokenizer, params):\n        decoded_unicode = new_response[\"text\"]\n        delta_text = decoded_unicode[len(previous_text):]\n        previous_text = decoded_unicode\n\n        finish_reason = new_response[\"finish_reason\"]\n        if len(delta_text) == 0 and finish_reason != \"function_call\":\n            continue\n\n        function_call = None\n        if finish_reason == \"function_call\":\n            try:\n                function_call = process_response(decoded_unicode, use_tool=True)\n            except:\n                logger.warning(\n                    \"Failed to parse tool call, maybe the response is not a tool call or have been answered.\")\n\n        if isinstance(function_call, dict):\n            function_call = FunctionCallResponse(**function_call)\n\n        delta = DeltaMessage(\n            content=delta_text,\n            role=\"assistant\",\n            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,\n        )\n\n        choice_data = ChatCompletionResponseStreamChoice(\n            index=0,\n            delta=delta,\n            finish_reason=finish_reason\n        )\n        chunk = ChatCompletionResponse(\n            model=model_id,\n            id=\"\",\n            choices=[choice_data],\n            object=\"chat.completion.chunk\"\n        )\n        yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(),\n        finish_reason=\"stop\"\n    )\n    chunk = ChatCompletionResponse(\n        model=model_id,\n        id=\"\",\n        choices=[choice_data],\n        
object=\"chat.completion.chunk\"\n    )\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n    yield '[DONE]'\n\n\ndef predict_stream(model_id, gen_params):\n    \"\"\"\n    The function call is compatible with stream mode output.\n\n    The first seven characters are determined.\n    If not a function call, the stream output is directly generated.\n    Otherwise, the complete character content of the function call is returned.\n\n    :param model_id:\n    :param gen_params:\n    :return:\n    \"\"\"\n    output = \"\"\n    is_function_call = False\n    has_send_first_chunk = False\n    for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):\n        decoded_unicode = new_response[\"text\"]\n        delta_text = decoded_unicode[len(output):]\n        output = decoded_unicode\n\n        # When it is not a function call and the character length is> 7,\n        # try to judge whether it is a function call according to the special function prefix\n        if not is_function_call and len(output) > 7:\n\n            # Determine whether a function is called\n            is_function_call = contains_custom_function(output)\n            if is_function_call:\n                continue\n\n            # Non-function call, direct stream output\n            finish_reason = new_response[\"finish_reason\"]\n\n            # Send an empty string first to avoid truncation by subsequent next() operations.\n            if not has_send_first_chunk:\n                message = DeltaMessage(\n                    content=\"\",\n                    role=\"assistant\",\n                    function_call=None,\n                )\n                choice_data = ChatCompletionResponseStreamChoice(\n                    index=0,\n                    delta=message,\n                    finish_reason=finish_reason\n                )\n                chunk = ChatCompletionResponse(\n                    model=model_id,\n                    id=\"\",\n                  
  choices=[choice_data],\n                    created=int(time.time()),\n                    object=\"chat.completion.chunk\"\n                )\n                yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n            send_msg = delta_text if has_send_first_chunk else output\n            has_send_first_chunk = True\n            message = DeltaMessage(\n                content=send_msg,\n                role=\"assistant\",\n                function_call=None,\n            )\n            choice_data = ChatCompletionResponseStreamChoice(\n                index=0,\n                delta=message,\n                finish_reason=finish_reason\n            )\n            chunk = ChatCompletionResponse(\n                model=model_id,\n                id=\"\",\n                choices=[choice_data],\n                created=int(time.time()),\n                object=\"chat.completion.chunk\"\n            )\n            yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    if is_function_call:\n        yield output\n    else:\n        yield '[DONE]'\n\n\nasync def parse_output_text(model_id: str, value: str):\n    \"\"\"\n    Directly output the text content of value\n\n    :param model_id:\n    :param value:\n    :return:\n    \"\"\"\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(role=\"assistant\", content=value),\n        finish_reason=None\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(),\n        finish_reason=\"stop\"\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n    yield 
'[DONE]'\n\n\ndef contains_custom_function(value: str) -> bool:\n    \"\"\"\n    Determine whether the output is a function call, according to a special function prefix.\n\n    For example, the functions defined in \"tools_using_demo/tool_register.py\" are all \"get_xxx\" and start with \"get_\"\n\n    [Note] This is not a rigorous judgment method, only for reference.\n\n    :param value:\n    :return:\n    \"\"\"\n    return bool(value and 'get_' in value)\n\n\nif __name__ == \"__main__\":\n    # Load LLM\n    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\n    model = AutoModel.from_pretrained(MODEL_PATH,\n                                      load_in_4bit=True,\n                                      trust_remote_code=True)\n    # Load the embedding model (required by the /v1/embeddings endpoint)\n    # embedding_model = SentenceTransformer(EMBEDDING_PATH, device=\"cuda\")\n    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)\n"
  },
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/chatglm3_infer.py",
    "content": "import time\nfrom ipex_llm.transformers import AutoModel\nfrom transformers import AutoTokenizer\n\nCHATGLM_V3_PROMPT_FORMAT = \"\\n{prompt}\\n\"\n\n# Please specify the local path to the chatglm3-6b model\n\nmodel_path = \"D:/AI/ChatGLM3/model/chatglm3-6b/\"\n\n# Load the ChatGLM3-6B model and quantize it to INT4\nmodel = AutoModel.from_pretrained(model_path,\n                                  load_in_4bit=True,\n                                  trust_remote_code=True)\n# Load the tokenizer\ntokenizer = AutoTokenizer.from_pretrained(model_path,\n                                          trust_remote_code=True)\n# Prepare a ChatGLM3-format prompt\nprompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=\"Who are you?\")\n# Encode the prompt\ninput_ids = tokenizer.encode(prompt, return_tensors=\"pt\")\nst = time.time()\n# Run inference and generate tokens\noutput = model.generate(input_ids, max_new_tokens=32)\nend = time.time()\n# Decode the generated tokens and display them\noutput_str = tokenizer.decode(output[0], skip_special_tokens=True)\nprint(f'Inference time: {end - st} s')\nprint('-'*20, 'Prompt', '-'*20)\nprint(prompt)\nprint('-'*20, 'Output', '-'*20)\nprint(output_str)\n"
  },
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/chatglm3_web_demo.py",
    "content": "\"\"\"\nThis script creates an interactive web demo for the ChatGLM3-6B model using Streamlit,\na Python library for building quick and easy UI components for machine learning models.\nIt's designed to showcase the capabilities of the ChatGLM3-6B model (quantized to INT4 via IPEX-LLM)\nin a user-friendly interface, allowing users to interact with the model through a chat-like interface.\n\nUsage:\n- Run the script with `streamlit run chatglm3_web_demo.py` to start the web server.\n- Interact with the model by typing questions and receiving responses.\n\nRequirements:\n- Streamlit should be installed.\n\"\"\"\n\nimport os\nimport streamlit as st\nfrom ipex_llm.transformers import AutoModel\nfrom transformers import AutoTokenizer\n\n\nst.set_page_config(\n    page_title=\"ChatGLM3-6B+BigDL-LLM demo\",\n    page_icon=\":robot:\",\n    layout=\"wide\"\n)\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\n\n\n@st.cache_resource\ndef get_model():\n    model = AutoModel.from_pretrained(MODEL_PATH,\n                                      load_in_4bit=True,\n                                      trust_remote_code=True)\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH,\n                                              trust_remote_code=True)\n    return tokenizer, model\n\n\ntokenizer, model = get_model()\n\nif \"history\" not in st.session_state:\n    st.session_state.history = []\nif \"past_key_values\" not in st.session_state:\n    st.session_state.past_key_values = None\n\nmax_length = st.sidebar.slider(\"max_length\", 0, 32768, 8192, step=1)\ntop_p = st.sidebar.slider(\"top_p\", 0.0, 1.0, 0.8, step=0.01)\ntemperature = st.sidebar.slider(\"temperature\", 0.0, 1.0, 0.6, step=0.01)\n\nbuttonClean = st.sidebar.button(\"Clear session history\", key=\"clean\")\nif 
buttonClean:\n    st.session_state.history = []\n    st.session_state.past_key_values = None\n    st.rerun()\n\nfor i, message in enumerate(st.session_state.history):\n    if message[\"role\"] == \"user\":\n        with st.chat_message(name=\"user\", avatar=\"user\"):\n            st.markdown(message[\"content\"])\n    else:\n        with st.chat_message(name=\"assistant\", avatar=\"assistant\"):\n            st.markdown(message[\"content\"])\n\nwith st.chat_message(name=\"user\", avatar=\"user\"):\n    input_placeholder = st.empty()\nwith st.chat_message(name=\"assistant\", avatar=\"assistant\"):\n    message_placeholder = st.empty()\n\nprompt_text = st.chat_input(\"please enter your question.\")\n\nif prompt_text:\n\n    input_placeholder.markdown(prompt_text)\n    history = st.session_state.history\n    past_key_values = st.session_state.past_key_values\n    for response, history, past_key_values in model.stream_chat(\n        tokenizer,\n        prompt_text,\n        history,\n        past_key_values=past_key_values,\n        max_length=max_length,\n        top_p=top_p,\n        temperature=temperature,\n        return_past_key_values=True,\n    ):\n        message_placeholder.markdown(response)\n\n    st.session_state.history = history\n    st.session_state.past_key_values = past_key_values"
  },
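The Streamlit demo above keeps the conversation as a list of role/content dicts in `st.session_state` and redraws it on every rerun. A minimal stdlib-only sketch of that bookkeeping (`append_turn` and `render` are hypothetical helpers, not part of the script):

```python
# Illustrative sketch of the chat-history bookkeeping used by the Streamlit demo.
# append_turn / render are hypothetical helpers for demonstration only.

def append_turn(history, role, content):
    """Record one chat turn the way the demo stores it in st.session_state.history."""
    history.append({"role": role, "content": content})
    return history

def render(history):
    """Flatten the stored turns into the transcript the UI would redraw."""
    return "\n".join(f'{m["role"]}: {m["content"]}' for m in history)

history = []
append_turn(history, "user", "你好")
append_turn(history, "assistant", "你好！有什么我可以帮助你的吗？")
print(render(history))
```

On each Streamlit rerun the full list is replayed through `st.chat_message`, which is why the demo stores complete turns rather than only the latest response.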
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/generate.py",
    "content": "import torch\nimport time\nimport argparse\nimport numpy as np\n\nfrom ipex_llm.transformers import AutoModel\nfrom modelscope import AutoTokenizer\nfrom transformers import AutoTokenizer\n\n# you could tune the prompt based on your own model,\n# here the prompt tuning refers to https://github.com/THUDM/ChatGLM3/blob/main/PROMPT.md\nCHATGLM_V3_PROMPT_FORMAT = \"<|user|>\\n{prompt}\\n<|assistant|>\"\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for ModelScope ChatGLM3 model')\n    parser.add_argument('--repo-id-or-model-path', type=str, default=\"ZhipuAI/chatglm3-6b\",\n                        help='The ModelScope repo id for the ChatGLM3 model to be downloaded'\n                             ', or the path to the ModelScope checkpoint folder')\n    parser.add_argument('--prompt', type=str, default=\"AI是什么？\",\n                        help='Prompt to infer')\n    parser.add_argument('--n-predict', type=int, default=32,\n                        help='Max tokens to predict')\n\n    args = parser.parse_args()\n    model_path = args.repo_id_or_model_path\n\n    # Load model in 4 bit,\n    # which convert the relevant layers in the model into INT4 format\n    # It is important to set `model_hub='modelscope'`, otherwise model hub is default to be huggingface\n    model = AutoModel.from_pretrained(model_path,\n                                      load_in_4bit=True,\n                                      trust_remote_code=True,\n                                      model_hub='modelscope')\n\n    # Load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_path,\n                                              trust_remote_code=True)\n\n    # Generate predicted tokens\n    with torch.inference_mode():\n        prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=args.prompt)\n        input_ids = tokenizer.encode(prompt, return_tensors=\"pt\")\n        st = time.time()\n        # if 
your selected model is capable of utilizing previous key/value attentions\n        # to enhance decoding speed, but has `\"use_cache\": false` in its model config,\n        # it is important to set `use_cache=True` explicitly in the `generate` function\n        # to obtain optimal performance with IPEX-LLM INT4 optimizations\n        output = model.generate(input_ids,\n                                max_new_tokens=args.n_predict)\n        end = time.time()\n        output_str = tokenizer.decode(output[0], skip_special_tokens=True)\n        print(f'Inference time: {end - st} s')\n        print('-' * 20, 'Prompt', '-' * 20)\n        print(prompt)\n        print('-' * 20, 'Output', '-' * 20)\n        print(output_str)"
  },
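Before tokenization, generate.py wraps the raw question in ChatGLM3's single-turn chat template via `CHATGLM_V3_PROMPT_FORMAT`. That formatting step is pure string handling and can be exercised standalone (the `build_prompt` wrapper is illustrative, not part of the script):

```python
# The same single-turn template generate.py applies before tokenizing the prompt.
CHATGLM_V3_PROMPT_FORMAT = "<|user|>\n{prompt}\n<|assistant|>"

def build_prompt(question: str) -> str:
    """Wrap a raw user question in the <|user|>/<|assistant|> chat markers."""
    return CHATGLM_V3_PROMPT_FORMAT.format(prompt=question)

print(build_prompt("AI是什么？"))
```

The trailing `<|assistant|>` marker signals the model to generate the assistant turn; multi-turn prompts follow the same marker convention per PROMPT.md.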
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/openai_api_request.py",
    "content": "\"\"\"\nThis script is an example of using the OpenAI API to create various interactions with a ChatGLM3 model.\nIt includes functions to:\n\n1. Conduct a basic chat session, asking about weather conditions in multiple cities.\n2. Initiate a simple chat in Chinese, asking the model to tell a short story.\n3. Retrieve and print embeddings for a given text input.\n\nEach function demonstrates a different aspect of the API's capabilities, showcasing how to make requests\nand handle responses.\n\"\"\"\n\nfrom openai import OpenAI\nimport time\n\nbase_url = \"http://127.0.0.1:8000/v1/\"\nclient = OpenAI(api_key=\"EMPTY\", base_url=base_url)\n\n\ndef function_chat():\n    messages = [{\"role\": \"user\", \"content\": \"What's the weather like in San Francisco, Tokyo, and Paris?\"}]\n    tools = [\n        {\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"get_current_weather\",\n                \"description\": \"Get the current weather in a given location\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"location\": {\n                            \"type\": \"string\",\n                            \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                        },\n                        \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                    },\n                    \"required\": [\"location\"],\n                },\n            },\n        }\n    ]\n\n    response = client.chat.completions.create(\n        model=\"chatglm3-6b\",\n        messages=messages,\n        tools=tools,\n        tool_choice=\"auto\",\n    )\n    if response:\n        content = response.choices[0].message.content\n        print(content)\n    else:\n        print(\"Error:\", response.status_code)\n\n\ndef simple_chat(use_stream=True):\n    messages = [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's \"\n                       \"instructions carefully. Respond using markdown.\",\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"你好，请你用生动的话语给我讲一个小故事吧\"\n        }\n    ]\n    response = client.chat.completions.create(\n        model=\"chatglm3-6b\",\n        messages=messages,\n        stream=use_stream,\n        max_tokens=256,\n        temperature=0.8,\n        presence_penalty=1.1,\n        top_p=0.8)\n    if response:\n        if use_stream:\n            for chunk in response:\n                print(chunk.choices[0].delta.content)\n        else:\n            content = response.choices[0].message.content\n            print(content)\n    else:\n        print(\"Error:\", response.status_code)\n\n\nif __name__ == \"__main__\":\n    simple_chat(use_stream=False)\n    simple_chat(use_stream=True)\n"
  },
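The tool definition passed to `client.chat.completions.create` is plain JSON-serializable data, so its shape can be checked offline before any request is made. A minimal sketch, mirroring the schema in the script (`validate_tool` is a hypothetical helper, not an OpenAI API):

```python
import json

def validate_tool(tool: dict) -> bool:
    """Check the minimal shape an OpenAI-style `tools` entry expects:
    type "function", a string name, and an object parameter schema."""
    fn = tool.get("function", {})
    return (
        tool.get("type") == "function"
        and isinstance(fn.get("name"), str)
        and fn.get("parameters", {}).get("type") == "object"
    )

tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

assert validate_tool(tool)
print(json.dumps(tool, indent=2))
```

Validating locally catches malformed schemas before the server rejects the request with a less helpful error.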
  {
    "path": "Intel_device_demo/ipex_llm_cpu_demo/utils.py",
    "content": "import gc\nimport json\nimport torch\nfrom transformers import PreTrainedModel, PreTrainedTokenizer\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom typing import Union, Tuple\n\n\nclass InvalidScoreLogitsProcessor(LogitsProcessor):\n    def __call__(\n            self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        if torch.isnan(scores).any() or torch.isinf(scores).any():\n            scores.zero_()\n            scores[..., 5] = 5e4\n        return scores\n\n\ndef process_response(output: str, use_tool: bool = False) -> Union[str, dict]:\n    content = \"\"\n    for response in output.split(\"<|assistant|>\"):\n        metadata, content = response.split(\"\\n\", maxsplit=1)\n        if not metadata.strip():\n            content = content.strip()\n            content = content.replace(\"[[训练时间]]\", \"2023年\")\n        else:\n            if use_tool:\n                content = \"\\n\".join(content.split(\"\\n\")[1:-1])\n\n                def tool_call(**kwargs):\n                    return kwargs\n\n                parameters = eval(content)\n                content = {\n                    \"name\": metadata.strip(),\n                    \"arguments\": json.dumps(parameters, ensure_ascii=False)\n                }\n            else:\n                content = {\n                    \"name\": metadata.strip(),\n                    \"content\": content\n                }\n    return content\n\n\n@torch.inference_mode()\ndef generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):\n    messages = params[\"messages\"]\n    tools = params[\"tools\"]\n    temperature = float(params.get(\"temperature\", 1.0))\n    repetition_penalty = float(params.get(\"repetition_penalty\", 1.0))\n    top_p = float(params.get(\"top_p\", 1.0))\n    max_new_tokens = int(params.get(\"max_tokens\", 256))\n    echo = params.get(\"echo\", True)\n    messages = 
process_chatglm_messages(messages, tools=tools)\n    query, role = messages[-1][\"content\"], messages[-1][\"role\"]\n\n    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)\n    inputs = inputs.to(model.device)\n    input_echo_len = len(inputs[\"input_ids\"][0])\n\n    if input_echo_len >= model.config.seq_length:\n        print(f\"Input length larger than {model.config.seq_length}\")\n\n    eos_token_id = [\n        tokenizer.eos_token_id,\n        tokenizer.get_command(\"<|user|>\"),\n    ]\n\n    gen_kwargs = {\n        \"max_new_tokens\": max_new_tokens,\n        \"do_sample\": True if temperature > 1e-5 else False,\n        \"top_p\": top_p,\n        \"repetition_penalty\": repetition_penalty,\n        \"logits_processor\": [InvalidScoreLogitsProcessor()],\n    }\n    if temperature > 1e-5:\n        gen_kwargs[\"temperature\"] = temperature\n\n    total_len = 0\n    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):\n        total_ids = total_ids.tolist()[0]\n        total_len = len(total_ids)\n        if echo:\n            output_ids = total_ids[:-1]\n        else:\n            output_ids = total_ids[input_echo_len:-1]\n\n        response = tokenizer.decode(output_ids)\n        if response and response[-1] != \"�\":\n            response, stop_found = apply_stopping_strings(response, [\"<|observation|>\"])\n\n            yield {\n                \"text\": response,\n                \"usage\": {\n                    \"prompt_tokens\": input_echo_len,\n                    \"completion_tokens\": total_len - input_echo_len,\n                    \"total_tokens\": total_len,\n                },\n                \"finish_reason\": \"function_call\" if stop_found else None,\n            }\n\n            if stop_found:\n                break\n\n    # Only last stream result contains finish_reason, we set finish_reason as stop\n    ret = {\n        \"text\": response,\n        \"usage\": {\n            
\"prompt_tokens\": input_echo_len,\n            \"completion_tokens\": total_len - input_echo_len,\n            \"total_tokens\": total_len,\n        },\n        \"finish_reason\": \"stop\",\n    }\n    yield ret\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\ndef process_chatglm_messages(messages, tools=None):\n    _messages = messages\n    messages = []\n    if tools:\n        messages.append(\n            {\n                \"role\": \"system\",\n                \"content\": \"Answer the following questions as best as you can. You have access to the following tools:\",\n                \"tools\": tools\n            }\n        )\n\n    for m in _messages:\n        role, content, func_call = m.role, m.content, m.function_call\n        if role == \"function\":\n            messages.append(\n                {\n                    \"role\": \"observation\",\n                    \"content\": content\n                }\n            )\n\n        elif role == \"assistant\" and func_call is not None:\n            for response in content.split(\"<|assistant|>\"):\n                metadata, sub_content = response.split(\"\\n\", maxsplit=1)\n                messages.append(\n                    {\n                        \"role\": role,\n                        \"metadata\": metadata,\n                        \"content\": sub_content.strip()\n                    }\n                )\n        else:\n            messages.append({\"role\": role, \"content\": content})\n    return messages\n\n\ndef generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):\n    for response in generate_stream_chatglm3(model, tokenizer, params):\n        pass\n    return response\n\n\ndef apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:\n    stop_found = False\n    for string in stop_strings:\n        idx = reply.find(string)\n        if idx != -1:\n            reply = reply[:idx]\n            stop_found = True\n            break\n\n    if not 
stop_found:\n        # If something like \"\\nYo\" is generated just before \"\\nYou:\" is completed, trim it\n        for string in stop_strings:\n            for j in range(len(string) - 1, 0, -1):\n                if reply[-j:] == string[:j]:\n                    reply = reply[:-j]\n                    break\n            else:\n                continue\n\n            break\n\n    return reply, stop_found\n"
  },
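The `apply_stopping_strings` helper in utils.py does two things: it cuts the reply at the first fully generated stop string, and, failing that, trims a partial prefix of a stop string still being generated at the end of the reply. The logic is pure Python and can be exercised standalone (reproduced here verbatim from the file):

```python
from typing import Tuple

def apply_stopping_strings(reply: str, stop_strings) -> Tuple[str, bool]:
    """Cut the reply at a completed stop string, or trim a partially
    generated one at the end of the reply (same logic as utils.py)."""
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # Check every proper prefix of each stop string against the reply's tail
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue
            break

    return reply, stop_found

print(apply_stopping_strings("42<|observation|>tail", ["<|observation|>"]))  # ('42', True)
print(apply_stopping_strings("42<|obs", ["<|observation|>"]))                # ('42', False)
```

The second case is why streamed output never flashes a half-typed `<|observation|>` marker: the partial prefix is withheld until it either completes (and triggers the stop) or turns out to be ordinary text.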
  {
    "path": "Intel_device_demo/openvino_demo/README.md",
    "content": "\n# 使用 OpenVINO 部署ChatGLM3-6B 模型\n\n[OpenVINO](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) 是 Intel 为深度学习推理而设计的开源工具包。它可以帮助开发者优化模型，提高推理性能，减少模型的内存占用。本示例将展示如何使用 OpenVINO 部署 ChatGLM3。\n\n你需要克隆本仓库，然后按照以下步骤进行操作来将模型转换你的 OpenVINO IR 模型，随后进行推理。\n## 1. 环境配置\n\n首先，克隆OpenVINO GLM3 推理仓库并安装依赖。\n\n```bash\ngit clone https://github.com/OpenVINO-dev-contest/chatglm3.openvino.git\ncd chatglm3.openvino\n```\n\n接着，我们推荐您新建一个虚拟环境，然后按照以下安装依赖。\n\n```\npython3 -m venv openvino_env\n\nsource openvino_env/bin/activate\n\npython3 -m pip install --upgrade pip\n\npip install wheel setuptools\n\npip install -r requirements.txt\n```\n\n## 2. 转换模型\n\n由于需要将Huggingface模型转换为OpenVINO IR模型，因此您需要下载模型并转换。\n\n```\npython3 convert.py --model_id THUDM/chatglm3-6b --output {your_path}/chatglm3-6b \n```\n\n### 可以选择的参数\n\n* `--model_id` - 模型所在目录的路径（绝对路径）。\n* `--output` - 转换后模型保存的地址\n\n## 3. 量化模型（非必须）\n\n```\npython3 quantize.py --model_path {your_path}/chatglm3-6b --precision int4 --output {your_path}/chatglm3-6b-int4\n```\n\n### 可以选择的参数\n\n* `--model_path` - OpenVINO IR 模型所在目录的路径。\n* `-- precision` - 量化精度：int8 或 int4。\n* `--output` - 保存模型的路径。\n\n## 4. 
运行 ChatGLM3 模型\n\n```bash\npython3 chat.py --model_path {your_path}/chatglm3-6b --max_sequence_length 4096 --device CPU\n```\n\n### 可以选择的参数\n\n* `--model_path` - OpenVINO IR 模型所在目录的路径。\n* `--max_sequence_length` - 输出 token 的最大长度。\n* `--device` - 运行推理的设备。\n\n## 例子\n\n```\n用户: 你好\nChatGLM3-6B-OpenVINO: 你好！有什么我可以帮助你的吗？\n\n用户: 你是谁？\nChatGLM3-6B-OpenVINO: 我是一个名为ChatGLM3-6B的人工智能助手，是由清华大学KEG实验室和智谱AI 公司于2023 年共同训练的语言模型开发而成。我的任务是针对用户的问题和要求提供适当的答复和支持。\n\n用户: 请给我讲一个有趣的故事\nChatGLM3-6B-OpenVINO: 从前，有一个名叫小明的小男孩，他是一个非常喜欢动物的人。有一天，他在森林里散步时，发现了一个非常漂亮的小鸟。小鸟受伤了，无法飞行。小明非常心疼，于是决定照顾这只小鸟。小明带着小鸟回家，为它搭建了一个小小的巢穴，并找来了一些软草和食物。每天，他都会给小鸟喂食，并为它换水。渐渐地，小鸟的伤势好了起来，开始在小明的家里飞来飞去，它们成了非常好的朋友。然而，一天，小明的父母告诉他，他们必须把小明养的小鸟送到森林里去。小明非常伤心，因为他已经和小鸟成为了好朋友。但是，他的父母告诉他，小鸟在森林里会更加自由自在，而且他也可以继续观看小鸟在森林中的生活。于是，小明和他的父母一起将小鸟送到了森林中。小鸟非常高兴，因为它又可以飞行了，并且还有许多其他的小动物朋友。小明也感到非常开心，因为他知道，即使不能一直拥有小鸟，他仍然可以欣赏到它们在自然中的美丽。从此以后，小明常常来到森林中，寻找小鸟。\n\n用户: 请给这个故事起一个标题\nChatGLM3-6B-OpenVINO: 《友谊的力量：小明与小鸟的森林冒险》\n```\n\n## 常见问题\n\n1. 为什么导入本地模型还会报 Hugging Face 链接错误？\n   - 将 transformers 库降级到 4.37.2 版本。\n\n2. 需要安装 OpenVINO C++ 推理引擎吗？\n   - 不需要。\n\n3. 一定要使用 Intel 的硬件吗？\n   - 我们仅在 Intel 设备上进行过尝试，推荐使用 x86 架构的英特尔设备，包括但不限于：\n   - 英特尔的 CPU，包括个人电脑 CPU 和服务器 CPU。\n   - 英特尔的独立显卡，例如：ARC A770 显卡。
  },
  {
    "path": "Intel_device_demo/openvino_demo/openvino_cli_demo.py",
    "content": "import argparse\nfrom typing import List, Tuple\nfrom threading import Thread\nimport torch\nfrom optimum.intel.openvino import OVModelForCausalLM\nfrom transformers import (AutoTokenizer, AutoConfig,\n                          TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria)\n\n\ndef parse_text(text):\n    lines = text.split(\"\\n\")\n    lines = [line for line in lines if line != \"\"]\n    count = 0\n    for i, line in enumerate(lines):\n        if \"```\" in line:\n            count += 1\n            items = line.split('`')\n            if count % 2 == 1:\n                lines[i] = f'<pre><code class=\"language-{items[-1]}\">'\n            else:\n                lines[i] = f'<br></code></pre>'\n        else:\n            if i > 0:\n                if count % 2 == 1:\n                    line = line.replace(\"`\", \"\\`\")\n                    line = line.replace(\"<\", \"&lt;\")\n                    line = line.replace(\">\", \"&gt;\")\n                    line = line.replace(\" \", \"&nbsp;\")\n                    line = line.replace(\"*\", \"&ast;\")\n                    line = line.replace(\"_\", \"&lowbar;\")\n                    line = line.replace(\"-\", \"&#45;\")\n                    line = line.replace(\".\", \"&#46;\")\n                    line = line.replace(\"!\", \"&#33;\")\n                    line = line.replace(\"(\", \"&#40;\")\n                    line = line.replace(\")\", \"&#41;\")\n                    line = line.replace(\"$\", \"&#36;\")\n                lines[i] = \"<br>\" + line\n    text = \"\".join(lines)\n    return text\n\n\nclass StopOnTokens(StoppingCriteria):\n    def __init__(self, token_ids):\n        self.token_ids = token_ids\n\n    def __call__(\n            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs\n    ) -> bool:\n        for stop_id in self.token_ids:\n            if input_ids[0][-1] == stop_id:\n                return True\n        return False\n\n\nif __name__ 
== \"__main__\":\n    parser = argparse.ArgumentParser(add_help=False)\n    parser.add_argument('-h',\n                        '--help',\n                        action='help',\n                        help='Show this help message and exit.')\n    parser.add_argument('-m',\n                        '--model_path',\n                        required=True,\n                        type=str,\n                        help='Required. model path')\n    parser.add_argument('-l',\n                        '--max_sequence_length',\n                        default=256,\n                        required=False,\n                        type=int,\n                        help='Required. maximun length of output')\n    parser.add_argument('-d',\n                        '--device',\n                        default='CPU',\n                        required=False,\n                        type=str,\n                        help='Required. device for inference')\n    args = parser.parse_args()\n    model_dir = args.model_path\n\n    ov_config = {\"PERFORMANCE_HINT\": \"LATENCY\",\n                 \"NUM_STREAMS\": \"1\", \"CACHE_DIR\": \"\"}\n\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_dir, trust_remote_code=True)\n\n    print(\"====Compiling model====\")\n    ov_model = OVModelForCausalLM.from_pretrained(\n        model_dir,\n        device=args.device,\n        ov_config=ov_config,\n        config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),\n        trust_remote_code=True,\n    )\n\n    streamer = TextIteratorStreamer(\n        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True\n    )\n    stop_tokens = [0, 2]\n    stop_tokens = [StopOnTokens(stop_tokens)]\n\n    def convert_history_to_token(history: List[Tuple[str, str]]):\n\n        messages = []\n        for idx, (user_msg, model_msg) in enumerate(history):\n            if idx == len(history) - 1 and not model_msg:\n                messages.append({\"role\": \"user\", 
\"content\": user_msg})\n                break\n            if user_msg:\n                messages.append({\"role\": \"user\", \"content\": user_msg})\n            if model_msg:\n                messages.append({\"role\": \"assistant\", \"content\": model_msg})\n\n        model_inputs = tokenizer.apply_chat_template(messages,\n                                                     add_generation_prompt=True,\n                                                     tokenize=True,\n                                                     return_tensors=\"pt\")\n        return model_inputs\n\n    history = []\n    print(\"====Starting conversation====\")\n    while True:\n        input_text = input(\"用户: \")\n        if input_text.lower() == 'stop':\n            break\n\n        if input_text.lower() == 'clear':\n            history = []\n            print(\"AI助手: 对话历史已清空\")\n            continue\n\n        print(\"ChatGLM3-6B-OpenVINO:\", end=\" \")\n        history = history + [[parse_text(input_text), \"\"]]\n        model_inputs = convert_history_to_token(history)\n        generate_kwargs = dict(\n            input_ids=model_inputs,\n            max_new_tokens=args.max_sequence_length,\n            temperature=0.1,\n            do_sample=True,\n            top_p=1.0,\n            top_k=50,\n            repetition_penalty=1.1,\n            streamer=streamer,\n            stopping_criteria=StoppingCriteriaList(stop_tokens)\n        )\n\n        t1 = Thread(target=ov_model.generate, kwargs=generate_kwargs)\n        t1.start()\n\n        partial_text = \"\"\n        for new_text in streamer:\n            new_text = new_text\n            print(new_text, end=\"\", flush=True)\n            partial_text += new_text\n        print(\"\\n\")\n        history[-1][1] = partial_text"
  },
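`convert_history_to_token` in the CLI demo first flattens the `[[user, assistant], ...]` history pairs into role-tagged messages before handing them to `tokenizer.apply_chat_template`. That flattening step is pure Python and can be sketched on its own (`history_to_messages` is a hypothetical helper split out for illustration):

```python
from typing import List, Tuple

def history_to_messages(history: List[Tuple[str, str]]) -> list:
    """Flatten (user, assistant) pairs into chat messages; the final
    pair with an empty assistant slot becomes the pending user turn."""
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages

history = [["你好", "你好！有什么我可以帮助你的吗？"], ["讲个故事", ""]]
print(history_to_messages(history))
```

Keeping the last pair's assistant slot empty until generation finishes is what lets the demo append the streamed reply back into `history[-1][1]` afterwards.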
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2024 ChatGLM team @ Zhipu AI\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MODEL_LICENSE",
    "content": "The ChatGLM3-6B License\n\n1. 定义\n\n“许可方”是指分发其软件的 ChatGLM3-6B 模型团队。\n\n“软件”是指根据本许可提供的 ChatGLM3-6B 模型参数。\n\n2. 许可授予\n\n根据本许可的条款和条件，许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。\n本许可允许您免费使用本仓库中的所有开源模型进行学术研究，对于希望将模型用于商业目的的用户，需在[这里](https://open.bigmodel.cn/mla/form)完成登记。\n经过登记的用户可以免费使用本模型进行商业活动，但必须遵守本许可的所有条款和条件。\n上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。\n\n3.限制\n\n您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。\n\n您不得利用本软件从事任何危害国家安全和国家统一，危害社会公共利益及公序良俗，侵犯他人商业秘密、知识产权、名誉权、肖像权、财产权等权益的行为。\n\n您在使用中应遵循使用地所适用的法律法规政策、道德规范等要求。\n\n4.免责声明\n\n本软件“按原样”提供，不提供任何明示或暗示的保证，包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下，作者或版权持有人均不对任何索赔、损害或其他责任负责，无论是在合同诉讼、侵权行为还是其他方面，由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。\n\n5. 责任限制\n\n除适用法律禁止的范围外，在任何情况下且根据任何法律理论，无论是基于侵权行为、疏忽、合同、责任或其他原因，任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害，或任何其他商业损失，即使许可人已被告知此类损害的可能性。\n\n6.争议解决\n\n本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。\n\n请注意，许可证可能会更新到更全面的版本。 有关许可和版权的任何问题，请通过 license@zhipuai.cn 与我们联系。\n\n1. Definitions\n\n“Licensor” means the ChatGLM3-6B Model Team that distributes its Software.\n\n“Software” means the ChatGLM3-6B model parameters made available under this license.\n\n2. License Grant\n\nUnder the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license.\nThis license permits you to use all open-source models in this repository for academic research free. Users who wish to use the models for commercial purposes must register [here](https://open.bigmodel.cn/mla/form).\nRegistered users may use the models for commercial activities free of charge, but must comply with all terms and conditions of this license.\nThe license notice shall be included in all copies or substantial portions of the Software.\n\n3. 
Restrictions\nYou are not allowed to use, copy, modify, merge, publish, distribute, copy or create all or part of the derivative works of this software for any military or illegal purposes.\nYou are not allowed to use this software to engage in any behavior that endangers national security and unity, endangers social public interests and public order, infringes on the rights and interests of others such as trade secrets, intellectual property rights, reputation rights, portrait rights, and property rights.\nYou should comply with the applicable laws, regulations, policies, ethical standards, and other requirements in the place of use during use.\n\n4. Disclaimer\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n\n5. Limitation of Liability\n\nEXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.\n\n6. Dispute Resolution\n\nThis license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.\n\nNote that the license is subject to update to a more comprehensive version.  For any questions related to the license and copyright, please contact us at license@zhipuai.cn.\n"
  },
  {
    "path": "PROMPT.md",
    "content": "## ChatGLM3 对话格式\n为了避免用户输入的注入攻击，以及统一 Code Interpreter，Tool & Agent 等任务的输入，ChatGLM3 采用了全新的对话格式。\n\n### 规定\n#### 整体结构\nChatGLM3 对话的格式由若干对话组成，其中每个对话包含对话头和内容，一个典型的多轮对话结构如下\n```text\n<|system|>\nYou are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\n<|user|>\nHello\n<|assistant|>\nHello, I'm ChatGLM3. What can I assist you today?\n```\n**实际中每轮对话内容并不一定以换行符结尾，这里只是为了美观，下同**\n\n#### 对话头\n对话头占完整的一行，格式为\n```text\n<|role|>{metadata}\n```\n其中 `<|role|>` 部分使用 special token 表示，无法从文本形式被 tokenizer 编码以防止注入。metadata 部分采用纯文本表示，为可选内容。\n* `<|system|>`：系统信息，设计上可穿插于对话中，**但目前规定仅可以出现在开头**\n* `<|user|>`：用户\n  - 不会连续出现多个来自 `<|user|>` 的信息\n* `<|assistant|>`：AI 助手\n  - 在出现之前必须有一个来自 `<|user|>` 的信息\n* `<|observation|>`：外部的返回结果\n  - 必须在 `<|assistant|>` 的信息之后\n\n### 样例场景\n\n为提升可读性，下列样例场景中表示角色的 special token 前均额外添加了一个换行符。实际使用及 tokenizer 实现中均无需额外添加这一换行。\n\n#### 多轮对话\n* 有且仅有 `<|user|>`、`<|assistant|>`、`<|system|>` 三种 role\n```text\n<|system|>\nYou are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\n<|user|>\nHello\n<|assistant|>\nHello, I'm ChatGLM3. What can I assist you today?\n```\n\n#### 工具调用\n````\n<|system|>\nAnswer the following questions as best as you can. You have access to the following tools:\n[\n    {\n        \"name\": \"get_current_weather\",\n        \"description\": \"Get the current weather in a given location\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\n                    \"type\": \"string\",\n                    \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                },\n                \"unit\": {\"type\": \"string\"},\n            },\n            \"required\": [\"location\"],\n        },\n    }\n]\n<|user|>\n今天北京的天气怎么样？\n<|assistant|>\n好的，让我们来查看今天的天气\n<|assistant|>get_current_weather\n```python\ntool_call(location=\"beijing\", unit=\"celsius\")\n```\n<|observation|>\n{\"temperature\": 22}\n<|assistant|>\n根据查询结果，今天北京的气温为 22 摄氏度。\n````\n\n#### 代码执行\n* 有 `<|user|>`、`<|assistant|>`、`<|system|>`、`<|observation|>` 四种 role。其中 `<|assistant|>` 的 metadata 只有 interpreter。\n\n`````text\n<|system|>\n你是一位智能AI助手，你叫ChatGLM3，你连接着一台电脑，但请注意不能联网。在使用Python解决任务时，你可以运行代码并得到结果，如果运行结果有错误，你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件，文件默认存储路径是/mnt/data/。\n<|user|>\n#File: /mnt/data/metadata.jsonl\n#Size: 35380\n#File uploaded\n文件中是否存在缺失值或异常值？\n<|assistant|>\n首先，我将读取您提供的文件，并查看其内容以确定是否存在缺失值或异常值。\n\n我们从加载文件开始：\n<|assistant|>interpreter\n```python\nimport json\n\n# Load the uploaded file\ndata = []\nwith open(\"/mnt/data/metadata.jsonl\", 'r') as file:\n    for line in file:\n        data.append(json.loads(line))\n\n# Display the first few records to get a sense of the data\ndata[:5]\n```\n<|observation|>\n```result\n[{'file_name': 'aceinthehole.png',\n  'name': 'Ace in the Hole',\n  'type': 'survivor',\n  'description': 'Lady Luck always seems to be throwing something good your way.'},\n {'file_name': 'adrenaline.png',\n  'name': 'Adrenaline',\n  'type': 'survivor',\n  'description': 'You are fuelled by unexpected energy when on the verge of escape.'},\n {'file_name': 'aftercare.png',\n  'name': 'Aftercare',\n  'type': 'survivor',\n  'description': \"Unlocks potential in one's aura reading ability.\"},\n {'file_name': 'agitation.png',\n  'name': 'Agitation',\n  'type': 'killer',\n  'description': 'You get excited in anticipation of hooking your prey.'},\n {'file_name': 'alert.png',\n  'name': 'Alert',\n  'type': 'survivor',\n  'description': 'Your acute senses are on high 
alert.'}]\n```\n<|assistant|>\n该文件看起来包含有关某些条目的元数据，每个条目有以下字段：\n\n- `file_name`: 文件名称\n- `name`: 名称\n- `type`: 类型（例如 \"survivor\" 或 \"killer\"）\n- `description`: 描述\n\n接下来，我将检查这些字段是否存在缺失值。\n<|assistant|>interpreter\n```python\n# Check for missing values in each column\nmissing_values = {key: sum(1 for item in data if key not in item) for key in data[0].keys()}\n\nmissing_values\n```\n<|observation|>\n```result\n{'file_name': 0, 'name': 0, 'type': 0, 'description': 0}\n```\n<|assistant|>\n每个字段都没有缺失值。这意味着每个条目都完整。\n\n接下来，我们可以检查是否存在任何潜在的异常值。这通常需要更深入的分析和对数据的了解，但我可以从一些基本的统计和分布开始进行检查。为了这样做，我们可以查看“type”字段的值分布，看看是否有任何意外的类别。\n<|assistant|>interpreter\n```python\n# Check the distribution of values in the 'type' column\ntype_distribution = {key: sum(1 for item in data if item['type'] == key) for key in set(item['type'] for item in data)}\n\ntype_distribution\n```\n<|observation|>\n```result\n{'survivor': 116, 'killer': 103}\n```\n<|assistant|>\n“type”字段有两个值：“survivor”和“killer”。其中，“survivor”有116个条目，而“killer”有103个条目。\n\n从这个初步的检查中，我们没有发现任何明显的异常值。\n\n如果您有任何特定的标准或其他字段需要检查，我可以为您提供更详细的分析。您还希望我检查什么呢？\n<|user|>\n帮我画一个爱心\n<|assistant|>interpreter\n```python\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n# 爱心形状的参数方程\ndef heart(t):\n    x = 16 * np.sin(t) ** 3\n    y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)\n    return x, y\n\nt = np.linspace(0, 2 * np.pi, 1000)\nx, y = heart(t)\n\nplt.figure(figsize=(6, 6))\nplt.plot(x, y, color='red')\nplt.axis('equal')\nplt.axis('off')\nplt.show()\n```\n<|observation|>\n```result\n[Image]\n```\n<|assistant|>\n这是一个爱心形状。我使用了参数方程来描述这个形状，并使用 matplotlib 进行了绘制。如果您有任何其他需求或问题，请随时告诉我。\n<|user|>  # End\n`````"
  },
  {
    "path": "PROMPT_en.md",
"content": "## ChatGLM3 Chat Format\nTo avoid injection attacks from user input, and to unify the input of Code Interpreter, Tool & Agent, and other tasks, ChatGLM3 adopts a brand-new dialogue format.\n\n### Regulations\n#### Overall Structure\nThe format of the ChatGLM3 dialogue consists of several conversations, each of which contains a dialogue header and content. A typical multi-turn dialogue structure is as follows:\n```text\n<|system|>\nYou are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\n<|user|>\nHello\n<|assistant|>\nHello, I'm ChatGLM3. What can I assist you today?\n```\n\n#### Chat Header\nThe chat header occupies a complete line, formatted as:\n```text\n<|role|>{metadata}\n```\nThe `<|role|>` part is represented as a special token, which cannot be encoded by the tokenizer from its text form, to prevent injection attacks. The `metadata` part is represented in plain text and is optional.\n* `<|system|>`: System information, which by design can be interspersed in the dialogue, **but currently may only appear at the beginning**\n* `<|user|>`: User\n  - Multiple messages from `<|user|>` will not appear consecutively\n* `<|assistant|>`: AI assistant\n  - There must be a message from `<|user|>` before it appears\n* `<|observation|>`: External return result\n  - Must follow a message from `<|assistant|>`\n\n### Example Scenarios\n\nFor better readability, an extra `\\n` is added before each role special token. This extra `\\n` should not be added in actual use and tokenizer implementation.\n\n#### Multi-turn Dialogue\n* There are only three roles: `<|user|>`, `<|assistant|>`, and `<|system|>`.\n```text\n<|system|>\nYou are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\n<|user|>\nHello\n<|assistant|>\nHello, I'm ChatGLM3. 
What can I assist you today?\n```\n\n#### Tool Calling\n````\n<|system|>\nAnswer the following questions as best as you can. You have access to the following tools:\n[\n    {\n        \"name\": \"get_current_weather\",\n        \"description\": \"Get the current weather in a given location\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\n                    \"type\": \"string\",\n                    \"description\": \"The city and state, e.g. San Francisco, CA\",\n                },\n                \"unit\": {\"type\": \"string\"},\n            },\n            \"required\": [\"location\"],\n        },\n    }\n]\n<|user|>\nWhat's the weather in Beijing today?\n<|assistant|>\nOkay, let's look up the weather in Beijing today.\n<|assistant|>get_current_weather\n```python\ntool_call(location=\"beijing\", unit=\"celsius\")\n```\n<|observation|>\n{\"temperature\": 22}\n<|assistant|>\nAccording to the query results, the temperature in Beijing today is 22 degrees Celsius.\n````\n\n#### Code Execution\n* Four available roles: `<|user|>`, `<|assistant|>`, `<|system|>`, and `<|observation|>`, where the `metadata` of `<|assistant|>` can only be `interpreter`.\n\n`````text\n<|system|>\nYou are an intelligent AI assistant, named ChatGLM3. You are connected to a computer, but please note that you cannot connect to the internet. When using Python to solve tasks, you can run the code and get the results. If there are errors in the running results, you need to improve the code as much as possible. 
You can process files uploaded by users to your computer, and the default storage path for files is /mnt/data/.\n<|user|>\n#File: /mnt/data/metadata.jsonl\n#Size: 35380\n#File uploaded\nAre there missing or abnormal values in the file?\n<|assistant|>\nFirstly, I will read the file you provided and review its contents to determine if there are missing or abnormal values.\n\nWe start by loading the file:\n<|assistant|>interpreter\n```python\nimport json\n\n# Load the uploaded file\ndata = []\nwith open(\"/mnt/data/metadata.jsonl\", 'r') as file:\n    for line in file:\n        data.append(json.loads(line))\n\n# Display the first few records to get a sense of the data\ndata[:5]\n```\n<|observation|>\n```result\n[{'file_name': 'aceinthehole.png',\n  'name': 'Ace in the Hole',\n  'type': 'survivor',\n  'description': 'Lady Luck always seems to be throwing something good your way.'},\n {'file_name': 'adrenaline.png',\n  'name': 'Adrenaline',\n  'type': 'survivor',\n  'description': 'You are fuelled by unexpected energy when on the verge of escape.'},\n {'file_name': 'aftercare.png',\n  'name': 'Aftercare',\n  'type': 'survivor',\n  'description': \"Unlocks potential in one's aura reading ability.\"},\n {'file_name': 'agitation.png',\n  'name': 'Agitation',\n  'type': 'killer',\n  'description': 'You get excited in anticipation of hooking your prey.'},\n {'file_name': 'alert.png',\n  'name': 'Alert',\n  'type': 'survivor',\n  'description': 'Your acute senses are on high alert.'}]\n```\n<|assistant|>\nThis file appears to contain metadata about certain items, each with the following fields:\n\n- `file_name`\n- `name`\n- `type` (e.g. 
\"survivor\" or \"killer\")\n- `description`\n\nNext, I will check for missing values in these fields.\n<|assistant|>interpreter\n```python\n# Check for missing values in each column\nmissing_values = {key: sum(1 for item in data if key not in item) for key in data[0].keys()}\n\nmissing_values\n```\n<|observation|>\n```result\n{'file_name': 0, 'name': 0, 'type': 0, 'description': 0}\n```\n<|assistant|>\nEach field has no missing values. This means that each entry is complete.\n\nNext, we can check for any potential outliers. This usually requires deeper analysis and understanding of the data, but I can start by examining some basic statistics and distributions. To do this, we can check the value distribution of the \"type\" field to see if there are any unexpected categories.\n<|assistant|>interpreter\n```python\n# Check the distribution of values in the 'type' column\ntype_distribution = {key: sum(1 for item in data if item['type'] == key) for key in set(item['type'] for item in data)}\n\ntype_distribution\n```\n<|observation|>\n```result\n{'survivor': 116, 'killer': 103}\n```\n<|assistant|>\nThe 'type' field has two values: 'survivor' and 'killer'. Among them, 'survivor' has 116 entries, while 'killer' has 103 entries.\n\nFrom this preliminary examination, we did not find any obvious abnormal values.\n\nIf you have any specific standards or other fields that need to be checked, I can provide you with a more detailed analysis. 
What else do you want me to check?\n<|user|>\nHelp me draw a heart\n<|assistant|>interpreter\n```python\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n# The parametric equation for the shape of a heart\ndef heart(t):\n    x = 16 * np.sin(t) ** 3\n    y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)\n    return x, y\n\nt = np.linspace(0, 2 * np.pi, 1000)\nx, y = heart(t)\n\nplt.figure(figsize=(6, 6))\nplt.plot(x, y, color='red')\nplt.axis('equal')\nplt.axis('off')\nplt.show()\n```\n<|observation|>\n```result\n[Image]\n```\n<|assistant|>\nThis is a heart shape. I used parametric equations to describe this shape and drew it using Matplotlib. If you have any other needs or questions, please feel free to let me know.\n<|user|>  # End\n`````\n"
  },
  {
    "path": "README.md",
"content": "# ChatGLM3\n\n<p align=\"center\">\n📄<a href=\"https://arxiv.org/pdf/2406.12793\" target=\"_blank\"> Report </a> • 🤗 <a href=\"https://huggingface.co/THUDM/chatglm3-6b\" target=\"_blank\">HF Repo</a> • 🤖 <a href=\"https://modelscope.cn/models/ZhipuAI/chatglm3-6b\" target=\"_blank\">ModelScope</a> • 🟣 <a href=\"https://www.wisemodel.cn/models/ZhipuAI/chatglm3-6b\" target=\"_blank\">WiseModel</a> • 📔 <a href=\"https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof\" target=\"_blank\">Document</a> • 🧰 <a href=\"https://openxlab.org.cn/models/hot/THUDM\" target=\"_blank\">OpenXLab</a> • 🐦 <a href=\"https://twitter.com/thukeg\" target=\"_blank\">Twitter</a><br>\n</p>\n<p align=\"center\">\n    👋 加入我们的 <a href=\"https://discord.gg/fK2dz4bg\" target=\"_blank\">Discord</a> 和 <a href=\"resources/WECHAT.md\" target=\"_blank\">微信</a>\n</p>\n<p align=\"center\">\n📍在 <a href=\"https://www.chatglm.cn\">chatglm.cn</a> 体验更大规模的 ChatGLM 模型。\n</p>\n\n[Read this in English.](./README_en.md)\n\n📔 关于 `ChatGLM3-6B` 更为详细的使用信息，可以参考\n\n+ [ChatGLM3 开放技术文档](https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof?from=from_copylink)\n+ [Bilibili video](https://www.bilibili.com/video/BV1uC4y1J7yA)\n+ [YouTube video](https://www.youtube.com/watch?v=Pw9PB6R7ORA)\n\n## GLM-4 开源模型和API\n\n我们已经发布最新的 **GLM-4** 模型，该模型在多个指标上有了新的突破，您可以在以下渠道体验我们的最新模型。\n\n+ [GLM-4 开源模型](https://github.com/THUDM/GLM-4) 我们已经开源了 GLM-4-9B 系列模型，在各项指标的测试上有明显提升，欢迎尝试。\n+ [智谱清言](https://chatglm.cn/main/detail?fr=ecology_x) 体验最新版 GLM-4，包括 **GLMs，All tools** 等功能。\n+ [API平台](https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9) 新一代 API 平台已经上线，您可以直接在\n  API\n  平台上体验 `GLM-4-0520`、`GLM-4-air`、`GLM-4-airx`、`GLM-4-flash`、`GLM-4`、`GLM-3-Turbo`、`CharacterGLM-3`、`CogView-3`\n  等新模型。\n  其中 `GLM-4`、`GLM-3-Turbo` 两个模型支持了 `System Prompt`、`Function Call`、`Retrieval`、`Web_Search` 等新功能，欢迎体验。\n\n+ [GLM-4 API 开源教程](https://github.com/MetaGLM/glm-cookbook/) GLM-4 API 教程和基础应用，欢迎尝试。\n  API 相关问题可以在本开源教程中提问，或者使用 
[GLM-4 API AI助手](https://open.bigmodel.cn/shareapp/v1/?share_code=sQwt5qyqYVaNh1O_87p8O)\n  来获得常见问题的帮助。\n\n-----\n\n## ChatGLM3 介绍\n\n**ChatGLM3** 是智谱AI和清华大学 KEG 实验室联合发布的对话预训练模型。ChatGLM3-6B 是 ChatGLM3\n系列中的开源模型，在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上，ChatGLM3-6B 引入了如下特性：\n\n1. **更强大的基础模型：** ChatGLM3-6B 的基础模型 ChatGLM3-6B-Base\n   采用了更多样的训练数据、更充分的训练步数和更合理的训练策略。在语义、数学、推理、代码、知识等不同角度的数据集上测评显示，**ChatGLM3-6B-Base 具有在 10B 以下的基础模型中最强的性能**。\n2. **更完整的功能支持：** ChatGLM3-6B 采用了全新设计的 [Prompt 格式](PROMPT.md)，除正常的多轮对话外，同时原生支持[工具调用](tools_using_demo/README.md)（Function Call）、代码执行（Code Interpreter）和\n   Agent 任务等复杂场景。\n3. **更全面的开源序列：** 除了对话模型 [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b)\n   外，还开源了基础模型 [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base)\n   、长文本对话模型 [ChatGLM3-6B-32K](https://huggingface.co/THUDM/chatglm3-6b-32k)\n   和进一步强化了对于长文本理解能力的 [ChatGLM3-6B-128K](https://huggingface.co/THUDM/chatglm3-6b-128k)。以上所有权重对学术研究**完全开放**，在填写 [问卷](https://open.bigmodel.cn/mla/form) 进行登记后**亦允许免费商业使用**。\n\n-----\n\nChatGLM3 开源模型旨在与开源社区一起推动大模型技术发展，恳请开发者和大家遵守 [开源协议](MODEL_LICENSE)，勿将开源模型和代码及基于开源项目产生的衍生物用于任何可能给国家和社会带来危害的用途以及用于任何未经过安全评估和备案的服务。目前，本项目团队未基于\n**ChatGLM3 开源模型**开发任何应用，包括网页端、安卓、苹果 iOS 及 Windows App 等应用。\n\n尽管模型在训练的各个阶段都尽力确保数据的合规性和准确性，但由于 ChatGLM3-6B\n模型规模较小，且模型受概率随机性因素影响，无法保证输出内容的准确性。同时模型的输出容易被用户的输入误导。**本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。**\n\n## 模型列表\n\n|      Model       | Seq Length | Download 
|\n|:----------------:|:----------:|:--------:|\n|   ChatGLM3-6B    |     8k     | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b) \\| [WiseModel](https://www.wisemodel.cn/models/ZhipuAI/chatglm3-6b) \\| [OpenXLab](https://openxlab.org.cn/models/detail/THUDM/chatglm3-6b) |\n| ChatGLM3-6B-Base |     8k     | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-base) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-base) \\| [WiseModel](https://www.wisemodel.cn/models/ZhipuAI/chatglm3-6b-base) \\| [OpenXLab](https://openxlab.org.cn/models/detail/THUDM/chatglm3-6b-base) |\n| ChatGLM3-6B-32K  |    32k     | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-32k) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k) \\| [WiseModel](https://www.wisemodel.cn/models/ZhipuAI/chatglm3-6b-32k) \\| [OpenXLab](https://openxlab.org.cn/models/detail/THUDM/chatglm3-6b-32k) |\n| ChatGLM3-6B-128K |    128k    | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-128k) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-128k) \\| [OpenXLab](https://openxlab.org.cn/models/detail/THUDM/chatglm3-6b-128k) |\n\n<br> 请注意，所有模型的最新更新都会在 Huggingface 率先发布。ModelScope 和 WiseModel 由于没有与 Huggingface 同步，需要开发人员手动更新，可能会在\nHuggingface 更新后一段时间内同步更新。\n\n## 友情链接\n\n以下优秀开源仓库已经对 ChatGLM3-6B 模型深度支持，欢迎大家扩展学习。\n\n推理加速：\n\n* [chatglm.cpp](https://github.com/li-plus/chatglm.cpp): 类似 llama.cpp 的量化加速推理方案，实现笔记本上实时对话\n* [ChatGLM3-TPU](https://github.com/sophgo/ChatGLM3-TPU): 采用 TPU 加速推理方案，在算能端侧芯片 BM1684X（16T@FP16，内存16G）上实时运行约 7.5 token/s\n* 
[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main):\n  NVIDIA 开发的高性能 GPU 加速推理方案，可以参考此 [步骤](./tensorrt_llm_demo/README.md) 部署 ChatGLM3-6B 模型\n* [OpenVINO](https://github.com/openvinotoolkit):\n  Intel 开发的高性能 CPU 和 GPU 加速推理方案，可以参考此 [步骤](./Intel_device_demo/openvino_demo/README.md) 部署 ChatGLM3-6B 模型\n\n高效微调：\n\n* [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory): 优秀易上手的高效微调框架。\n\n应用框架：\n\n* [LangChain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat): 基于 ChatGLM 等大语言模型与 Langchain\n  等应用框架实现，开源、可离线部署的检索增强生成(RAG)大模型知识库项目。\n\n* [BISHENG](https://github.com/dataelement/bisheng): 开源大模型应用开发平台，赋能和加速大模型应用开发落地，帮助用户以最佳体验进入下一代应用开发模式。\n\n* [RAGFlow](https://github.com/infiniflow/ragflow): RAGFlow 是一款基于深度文档理解构建的开源 RAG（Retrieval-Augmented Generation）引擎。可为各种规模的企业及个人提供一套精简的 RAG 工作流程，结合大语言模型（LLM）针对用户各类不同的复杂格式数据提供可靠的问答以及有理有据的引用。\n\n## 评测结果\n\n### 典型任务\n\n我们选取了 8 个中英文典型数据集，在 ChatGLM3-6B (base) 版本上进行了性能测试。\n\n| Model            | GSM8K | MATH | BBH  | MMLU | C-Eval | CMMLU | MBPP | AGIEval |\n|------------------|:-----:|:----:|:----:|:----:|:------:|:-----:|:----:|:-------:|\n| ChatGLM2-6B-Base | 32.4  | 6.5  | 33.7 | 47.9 |  51.7  | 50.0  |  -   |    -    |\n| Best Baseline    | 52.1  | 13.1 | 45.0 | 60.1 |  63.5  | 62.2  | 47.5 |  45.8   |\n| ChatGLM3-6B-Base | 72.3  | 25.7 | 66.1 | 61.4 |  69.0  | 67.5  | 52.4 |  53.7   |\n\n> Best Baseline 指的是截至 2023年10月27日、模型参数在 10B 以下、在对应数据集上表现最好的预训练模型，不包括只针对某一项任务训练而未保持通用能力的模型。\n\n> 对 ChatGLM3-6B-Base 的测试中，BBH 采用 3-shot 测试，需要推理的 GSM8K、MATH 采用 0-shot CoT 测试，MBPP 采用 0-shot\n> 生成后运行测例计算 Pass@1，其他选择题类型数据集均采用 0-shot 测试。\n\n我们在多个长文本应用场景下对 ChatGLM3-6B-32K 进行了人工评估测试。与二代模型相比，其效果平均提升了超过\n50%。在论文阅读、文档摘要和财报分析等应用中，这种提升尤为显著。此外，我们还在 LongBench 评测集上对模型进行了测试，具体结果如下表所示：\n\n| Model           |  平均  | Summary | Single-Doc QA | Multi-Doc QA | Code | Few-shot | Synthetic |\n|-----------------|:----:|:-------:|:-------------:|:------------:|:----:|:--------:|:---------:|\n| ChatGLM2-6B-32K | 41.5 |  24.8   |     37.6      |     34.7     
| 52.8 |   51.3   |   47.7    |\n| ChatGLM3-6B-32K | 50.2 |  26.6   |     45.8      |     46.1     | 56.2 |   61.2   |    65     |\n\n## 使用方式\n\n### 环境安装\n\n首先需要下载本仓库：\n\n```shell\ngit clone https://github.com/THUDM/ChatGLM3\ncd ChatGLM3\n```\n\n然后使用 pip 安装依赖：\n\n```shell\npip install -r requirements.txt\n```\n\n+ 为了保证 `torch` 的版本正确，请严格按照 [官方文档](https://pytorch.org/get-started/locally/) 的说明安装。\n\n### 综合 Demo\n\n我们提供了一个集成以下三种功能的综合 Demo，运行方法请参考[综合 Demo](composite_demo/README.md)。\n\n- Chat: 对话模式，在此模式下可以与模型进行对话。\n- Tool: 工具模式，模型除了对话外，还可以通过工具进行其他操作。\n\n<img src=\"resources/tool.png\" width=\"400\">\n\n- Code Interpreter: 代码解释器模式，模型可以在一个 Jupyter 环境中执行代码并获取结果，以完成复杂任务。\n\n<img src=\"resources/heart.png\" width=\"400\">\n\n### 代码调用\n\n可以通过如下代码调用 ChatGLM 模型来生成对话：\n\n```python\n>>> from transformers import AutoTokenizer, AutoModel\n>>> tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True)\n>>> model = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True, device='cuda')\n>>> model = model.eval()\n>>> response, history = model.chat(tokenizer, \"你好\", history=[])\n>>> print(response)\n\n你好👋!我是人工智能助手 ChatGLM3-6B, 很高兴见到你, 欢迎问我任何问题。\n>>> response, history = model.chat(tokenizer, \"晚上睡不着应该怎么办\", history=history)\n>>> print(response)\n\n晚上睡不着可能会让你感到焦虑或不舒服, 但以下是一些可以帮助你入睡的方法:\n\n1.制定规律的睡眠时间表: 保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯, 使你更容易入睡。尽量在每天的相同时间上床, 并在同一时间起床。\n2.创造一个舒适的睡眠环境: 确保睡眠环境舒适, 安静, 黑暗且温度适宜。可以使用舒适的床上用品, 并保持房间通风。\n3.放松身心: 在睡前做些放松的活动, 例如泡个热水澡, 听些轻柔的音乐, 阅读一些有趣的书籍等, 有助于缓解紧张和焦虑, 使你更容易入睡。\n4.避免饮用含有咖啡因的饮料: 咖啡因是一种刺激性物质, 会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料, 例如咖啡, 茶和可乐。\n5.避免在床上做与睡眠无关的事情: 在床上做些与睡眠无关的事情, 例如看电影, 玩游戏或工作等, 可能会干扰你的睡眠。\n6.尝试呼吸技巧: 深呼吸是一种放松技巧, 可以帮助你缓解紧张和焦虑, 使你更容易入睡。试着慢慢吸气, 保持几秒钟, 然后缓慢呼气。\n\n如果这些方法无法帮助你入睡, 你可以考虑咨询医生或睡眠专家, 寻求进一步的建议。\n```\n\n#### 从本地加载模型\n\n以上代码会由 `transformers`\n自动下载模型实现和参数。完整的模型实现在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm3-6b)。如果你的网络环境较差，下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地，然后从本地加载。\n\n从 Hugging Face Hub\n下载模型需要先[安装Git 
LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)，然后运行\n\n```shell\ngit clone https://huggingface.co/THUDM/chatglm3-6b\n```\n\n如果你从 HuggingFace 下载比较慢，也可以从 [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b)\n中下载。\n\n### 模型微调\n\n我们提供了一个用于微调 ChatGLM3-6B 模型的基础套件。微调套件的使用方法请参考\n[微调套件](finetune_demo/README.md)。\n\n### 网页版对话 Demo\n\n![web-demo](resources/web-demo.gif)\n\n可以通过以下命令启动基于 Gradio 的网页版 demo：\n\n```shell\npython web_demo_gradio.py\n```\n\n![web-demo](resources/web-demo2.png)\n\n可以通过以下命令启动基于 Streamlit 的网页版 demo：\n\n```shell\nstreamlit run web_demo_streamlit.py\n```\n\n网页版 demo 会运行一个 Web Server，并输出地址。在浏览器中打开输出的地址即可使用。经测试，基于 Streamlit 的网页版 Demo 会更流畅。\n\n### 命令行对话 Demo\n\n![cli-demo](resources/cli-demo.png)\n\n运行仓库中 [cli_demo.py](basic_demo/cli_demo.py)：\n\n```shell\npython cli_demo.py\n```\n\n程序会在命令行中进行交互式的对话，在命令行中输入指示并回车即可生成回复，输入 `clear` 可以清空对话历史，输入 `stop` 终止程序。\n\n### LangChain Demo\n\n代码实现请参考 [LangChain Demo](langchain_demo/README.md)。\n\n#### 工具调用\n\n关于工具调用的方法请参考 [工具调用](tools_using_demo/README.md)。\n\n#### OpenAI API / Zhipu API Demo\n\n我们已经推出了 OpenAI / ZhipuAI 格式的开源模型 API 部署代码，可以作为任意基于 ChatGPT 的应用的后端。\n目前，可以通过运行仓库中的 [api_server.py](openai_api_demo/api_server.py) 进行部署：\n\n```shell\ncd openai_api_demo\npython api_server.py\n```\n\n同时，我们也编写了示例代码，用来测试 API 调用的性能。\n\n+ OpenAI 测试脚本：[openai_api_request.py](openai_api_demo/openai_api_request.py)\n+ ZhipuAI 测试脚本：[zhipu_api_request.py](openai_api_demo/zhipu_api_request.py)\n+ 使用 Curl 进行测试\n\n+ chat Curl 测试\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d \"{\\\"model\\\": \\\"chatglm3-6b\\\", \\\"messages\\\": [{\\\"role\\\": \\\"system\\\", \\\"content\\\": \\\"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. 
Respond using markdown.\\\"}, {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"你好，给我讲一个故事，大概100字\\\"}], \\\"stream\\\": false, \\\"max_tokens\\\": 100, \\\"temperature\\\": 0.8, \\\"top_p\\\": 0.8}\"\n````\n\n+ Standard openai interface agent-chat Curl 测试\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d \"{\\\"model\\\": \\\"chatglm3-6b\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"37乘以8加7除2等于多少？\\\"}], \"tools\": [{\"name\": \"track\", \"description\": \"追踪指定股票的实时价格\",\n          \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"description\": \"需要追踪的股票代码\"}},\n                         \"required\": []}},\n         {\"name\": \"Calculator\", \"description\": \"数学计算器，计算数学问题\",\n          \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"description\": \"要计算的数学公式\"}},\n                         \"required\": []}}\n         ], \\\"stream\\\": true, \\\"max_tokens\\\": 100, \\\"temperature\\\": 0.8, \\\"top_p\\\": 0.8}\"\n````\n\n+ Openai style custom interface agent-chat Curl 测试（你需要实现自定义的工具描述脚本openai_api_demo/tools/schema.py的内容，并且将api_server.py中AGENT_CONTROLLER指定为'true'）：\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d \"{\\\"model\\\": \\\"chatglm3-6b\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"37乘以8加7除2等于多少？\\\"}], \\\"stream\\\": true, \\\"max_tokens\\\": 100, \\\"temperature\\\": 0.8, \\\"top_p\\\": 0.8}\"\n````\n该接口用于openai风格的自定义工具箱的自主调度。具有调度异常的自处理回复能力，无需另外实现调度算法，用户无需api_key。\n\n\n+ 使用Python进行测试\n\n```shell\ncd openai_api_demo\npython openai_api_request.py\n```\n\n如果测试成功，则模型应该返回一段故事。\n\n## 低成本部署\n\n### 模型量化\n\n默认情况下，模型以 FP16 精度加载，运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限，可以尝试以量化方式加载模型，使用方法如下：\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", 
trust_remote_code=True).quantize(4).cuda()\n```\n\n模型量化会带来一定的性能损失，经过测试，ChatGLM3-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。\n\n### CPU 部署\n\n如果你没有 GPU 硬件的话，也可以在 CPU 上进行推理，但是推理速度会更慢。使用方法如下（需要大概 32GB 内存）：\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).float()\n```\n\n### Mac 部署\n\n对于搭载了 Apple Silicon 或者 AMD GPU 的 Mac，可以使用 MPS 后端来在 GPU 上运行 ChatGLM3-6B。需要参考 Apple\n的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly（正确的版本号应该是 2.x.x.dev2023xxxx，而不是\n2.x.x）。\n\n目前在 macOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载，并使用 mps 后端：\n\n```python\nmodel = AutoModel.from_pretrained(\"your local path\", trust_remote_code=True).to('mps')\n```\n\n加载半精度的 ChatGLM3-6B 模型需要大概 13GB 内存。内存较小的机器（比如 16GB 内存的 MacBook\nPro），在空余内存不足的情况下会使用硬盘上的虚拟内存，导致推理速度严重变慢。\n\n### 多卡部署\n\n如果你有多张 GPU，但是每张 GPU 的显存大小都不足以容纳完整的模型，那么可以将模型切分在多张 GPU 上。首先安装\naccelerate: `pip install accelerate`，然后即可正常加载模型。\n\n### OpenVINO Demo\n\nChatGLM3-6B 已经支持使用 OpenVINO\n工具包进行加速推理，在英特尔的 CPU 和 GPU 设备上有较大推理速度提升。具体使用方法请参考 [OpenVINO Demo](Intel_device_demo/openvino_demo/README.md)。\n\n### TensorRT-LLM Demo\n\nChatGLM3-6B 已经支持使用 TensorRT-LLM\n工具包进行加速推理，模型推理速度得到多倍的提升。具体使用方法请参考 [TensorRT-LLM Demo](tensorrt_llm_demo/tensorrt_llm_cli_demo.py)\n和官方技术文档。\n\n## 引用\n\n如果你觉得我们的工作有帮助的话，请考虑引用下列论文。\n\n```\n@misc{glm2024chatglm,\n      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, \n      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu 
and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},\n      year={2024},\n      eprint={2406.12793},\n      archivePrefix={arXiv},\n      primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}\n}\n```\n"
  },
  {
    "path": "README_en.md",
    "content": "# ChatGLM3\n\n<p align=\"center\">\n📄<a href=\"https://arxiv.org/pdf/2406.12793\" target=\"_blank\"> Report </a> • 🤗 <a href=\"https://huggingface.co/THUDM/chatglm3-6b\" target=\"_blank\">HF Repo</a> • 🤖 <a href=\"https://modelscope.cn/models/ZhipuAI/chatglm3-6b\" target=\"_blank\">ModelScope</a> • 🟣 <a href=\"https://www.wisemodel.cn/models/ZhipuAI/chatglm3-6b\" target=\"_blank\">WiseModel</a> • 📔 <a href=\"https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof\" target=\"_blank\">Document</a> •  🧰 <a href=\"https://openxlab.org.cn/models/hot/THUDM\" target=\"_blank\">OpenXLab</a> • 🐦 <a href=\"https://twitter.com/thukeg\" target=\"_blank\">Twitter</a><br>\n</p>\n<p align=\"center\">\n    👋 Join our  <a href=\"https://discord.gg/fK2dz4bg\" target=\"_blank\">Discord</a> and <a href=\"resources/WECHAT.md\" target=\"_blank\">WeChat</a>\n</p>\n<p align=\"center\">\n📍Experience the larger-scale ChatGLM model at <a href=\"https://www.chatglm.cn\">chatglm.cn</a>\n</p>\n\n📔 For more detailed usage information about `ChatGLM3-6B`, please refer to:\n\n+ [ChatGLM3 technical documentation](https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof?from=from_copylink)\n+ [Bilibili video](https://www.bilibili.com/video/BV1uC4y1J7yA)\n+ [YouTube video](https://www.youtube.com/watch?v=Pw9PB6R7ORA)\n\n## GLM-4 Open Source Model and API\n\nWe have released the latest **GLM-4** model, which has made new breakthroughs in multiple indicators. You can directly\nexperience our latest model through the following channels.\n\n+ [GLM-4 open source model](https://github.com/THUDM/GLM-4) We have open sourced the GLM-4-9B series models, which have\n  significantly improved the performance of various indicators. 
You are welcome to try them.\n+ [Zhipu Qingyan](https://chatglm.cn/main/detail?fr=ecology_x) Experience the latest version of GLM-4, including **GLMs,\n  All tools** and other functions.\n+ [API platform](https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9) The new generation of API\n  platform has been launched. You can directly experience new models such\n  as `GLM-4-0520`, `GLM-4-air`, `GLM-4-airx`, `GLM-4-flash`, `GLM-4`, `GLM-3-Turbo`, `CharacterGLM-3`, `CogView-3` on\n  the API platform.\n  Among them, the two models `GLM-4` and `GLM-3-Turbo` support new functions such\n  as `System Prompt`, `Function Call`, `Retrieval`, and `Web_Search`. You are welcome to experience them.\n\n+ [GLM4 API open source tutorial](https://github.com/MetaGLM/glm-cookbook/) GLM-4 API tutorials and basic applications;\n  you are welcome to try them.\n  API-related questions can be asked in this open source tutorial, or\n  use [GLM-4 API AI Assistant](https://open.bigmodel.cn/shareapp/v1/?share_code=sQwt5qyqYVaNh1O_87p8O)\n  to get help with common problems.\n\n-----\n\n## ChatGLM3 Introduction\n\n**ChatGLM3** is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B\nis the open-source model in the ChatGLM3 series, maintaining many excellent features of the first two generations such\nas smooth dialogue and low deployment threshold, while introducing the following features:\n\n1. **Stronger Base Model:** The base model of ChatGLM3-6B, ChatGLM3-6B-Base, adopts a more diverse training dataset,\n   more sufficient training steps, and a more reasonable training strategy. Evaluations on datasets from various\n   perspectives such as semantics, mathematics, reasoning, code, and knowledge show that **ChatGLM3-6B-Base has the\n   strongest performance among base models below 10B**.\n\n2. **More Complete Function Support:** ChatGLM3-6B adopts a newly designed [Prompt format](PROMPT_en.md), supporting\n   multi-turn dialogues as usual. 
It also natively supports [tool invocation](tools_using_demo/README_en.md) (Function\n   Call), code execution (Code Interpreter), and Agent tasks in complex scenarios.\n\n3. **More Comprehensive Open-source Series:** In addition to the dialogue\n   model [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b), the base\n   model [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base), the long-text dialogue\n   model [ChatGLM3-6B-32K](https://huggingface.co/THUDM/chatglm3-6b-32k), and\n   [ChatGLM3-6B-128K](https://huggingface.co/THUDM/chatglm3-6b-128k), which further strengthens long-text\n   understanding, have also been open-sourced.\n   All these weights are **fully open** for academic research, and **free commercial use is also allowed** after\n   registration via a [questionnaire](https://open.bigmodel.cn/mla/form).\n\n-----\n\nThe ChatGLM3 open-source model aims to promote the development of large-model technology together with the open-source\ncommunity. Developers and users are earnestly requested to comply with the [open-source protocol](MODEL_LICENSE), and\nnot to use the open-source models, code, or derivatives for any purpose that might harm the nation or society, or\nfor any service that has not been evaluated and filed for safety. Currently, no applications, including web, Android,\nApple iOS, and Windows App, have been developed based on the **ChatGLM3 open-source model** by our project team.\n\nAlthough every effort has been made to ensure the compliance and accuracy of the data at various stages of model\ntraining, due to the smaller scale of the ChatGLM3-6B model and the influence of probabilistic randomness factors, the\naccuracy of output content cannot be guaranteed. The model output is also easily misled by user input. 
**This project\ndoes not assume risks and liabilities caused by data security, public opinion risks, or any misleading, abuse,\ndissemination, and improper use of open-source models and codes.**\n\n## Model List\n\n|      Model       | Seq Length |                                                              Download                                                               |\n|:----------------:|:----------:|:-----------------------------------------------------------------------------------------------------------------------------------:|\n|   ChatGLM3-6B    |     8k     |      [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b)      |\n| ChatGLM3-6B-Base |     8k     | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-base) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-base) |\n| ChatGLM3-6B-32K  |    32k     |  [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-32k) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k)  |\n| ChatGLM3-6B-128K |    128k    | [HuggingFace](https://huggingface.co/THUDM/chatglm3-6b-128k) \\| [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-128k) |\n\n## Projects\n\nThe following excellent open source repositories have in-depth support for the ChatGLM3-6B model, and everyone is\nwelcome to explore them.\n\nInference acceleration:\n\n* [chatglm.cpp](https://github.com/li-plus/chatglm.cpp): Real-time inference on your laptop accelerated by quantization,\n  similar to llama.cpp.\n* [ChatGLM3-TPU](https://github.com/sophgo/ChatGLM3-TPU): A TPU-accelerated inference solution; it runs at about\n  7.5 tokens/s in real time on the edge-side chip BM1684X (16T@FP16, 16G DDR).\n* [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main): A high-performance GPU-accelerated inference\n  solution developed by NVIDIA; you can refer to these [steps](./tensorrt_llm_demo/README.md) to deploy ChatGLM3.\n* 
[OpenVINO](https://github.com/openvinotoolkit): A high-performance CPU- and GPU-accelerated inference solution\n  developed by Intel; you can refer to these [steps](./Intel_device_demo/openvino_demo/README.md) to deploy the\n  ChatGLM3-6B model.\n\nEfficient fine-tuning:\n\n* [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory): An excellent, easy-to-use and efficient fine-tuning\n  framework.\n\nApplication framework:\n\n* [LangChain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat): An open-source, offline-deployable\n  retrieval-augmented generation (RAG) knowledge base project based on large language models such as ChatGLM and\n  application frameworks such as Langchain.\n\n* [BISHENG](https://github.com/dataelement/bisheng): An open-source platform for developing LLM applications. It empowers\n  and accelerates the development of LLM applications and helps users enter the next generation of application\n  development with the best experience.\n\n* [RAGFlow](https://github.com/infiniflow/ragflow): An open-source RAG (Retrieval-Augmented Generation) engine based on\n  deep document understanding. 
It offers a streamlined RAG workflow for businesses of any scale, combining LLM (Large\n  Language Models) to provide truthful question-answering capabilities, backed by well-founded citations from various\n  complex formatted data.\n\n## Evaluation Results\n\n### Typical Tasks\n\nWe selected 8 typical Chinese-English datasets and conducted performance tests on the ChatGLM3-6B (base) version.\n\n| Model            | GSM8K | MATH | BBH  | MMLU | C-Eval | CMMLU | MBPP | AGIEval |\n|------------------|:-----:|:----:|:----:|:----:|:------:|:-----:|:----:|:-------:|\n| ChatGLM2-6B-Base | 32.4  | 6.5  | 33.7 | 47.9 |  51.7  | 50.0  |  -   |    -    |\n| Best Baseline    | 52.1  | 13.1 | 45.0 | 60.1 |  63.5  | 62.2  | 47.5 |  45.8   |\n| ChatGLM3-6B-Base | 72.3  | 25.7 | 66.1 | 61.4 |  69.0  | 67.5  | 52.4 |  53.7   |\n\n> \"Best Baseline\" refers to the pre-trained models that perform best on the corresponding datasets with model parameters\n> below 10B, excluding models that are trained specifically for a single task and do not maintain general capabilities.\n\n> In the tests of ChatGLM3-6B-Base, BBH used a 3-shot test, GSM8K and MATH that require inference used a 0-shot CoT\n> test, MBPP used a 0-shot generation followed by running test cases to calculate Pass@1, and other multiple-choice type\n> datasets all used a 0-shot test.\n\nWe have conducted manual evaluation tests on ChatGLM3-6B-32K in multiple long-text application scenarios. Compared with\nthe second-generation model, its effect has improved by more than 50% on average. In applications such as paper reading,\ndocument summarization, and financial report analysis, this improvement is particularly significant. 
In addition, we\nalso tested the model on the LongBench evaluation set, and the specific results are shown in the table below.\n\n| Model           | Average | Summary | Single-Doc QA | Multi-Doc QA | Code | Few-shot | Synthetic | \n|-----------------|:-------:|:-------:|:-------------:|:------------:|:----:|:--------:|:---------:|\n| ChatGLM2-6B-32K |  41.5   |  24.8   |     37.6      |     34.7     | 52.8 |   51.3   |   47.7    | \n| ChatGLM3-6B-32K |  50.2   |  26.6   |     45.8      |     46.1     | 56.2 |   61.2   |   65.0    |\n\n## How to Use\n\n### Environment Installation\n\nFirst, you need to download this repository:\n\n```shell\ngit clone https://github.com/THUDM/ChatGLM3\ncd ChatGLM3\n```\n\nThen use pip to install the dependencies:\n\n```shell\npip install -r requirements.txt\n```\n\n+ In order to ensure that the version of `torch` is correct, please strictly follow the instructions\n  of the [official documentation](https://pytorch.org/get-started/locally/) for installation.\n\n### Integrated Demo\n\nWe provide an integrated demo that incorporates the following three functionalities. 
Please refer\nto [Integrated Demo](composite_demo/README_en.md) for how to run it.\n\n- Chat: Dialogue mode, where you can interact with the model.\n- Tool: Tool mode, where in addition to dialogue, the model can also perform other operations using tools.\n  ![tool](resources/tool_en.png)\n- Code Interpreter: Code interpreter mode, where the model can execute code in a Jupyter environment and obtain results\n  to complete complex tasks.\n  ![code](resources/code_en.gif)\n\n### Usage\n\nThe ChatGLM model can be called to start a conversation using the following code:\n\n```python\n>>> from transformers import AutoTokenizer, AutoModel\n>>> tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True)\n>>> model = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True, device='cuda')\n>>> model = model.eval()\n>>> response, history = model.chat(tokenizer, \"你好\", history=[])\n>>> print(response)\n\nHello 👋! I am the AI assistant ChatGLM3-6B, nice to meet you, feel free to ask me any questions.\n\n>>> response, history = model.chat(tokenizer, \"What should I do if I can't sleep at night?\", history=history)\n>>> print(response)\n\nHaving trouble sleeping at night may make you feel anxious or uncomfortable, but here are some ways to help you fall asleep:\n\n1. Develop a regular sleep schedule: Keeping a regular sleep schedule can help you develop healthy sleep habits and make it easier for you to fall asleep. Try to go to bed at the same time every day and get up at the same time.\n\n2. Create a comfortable sleeping environment: Make sure the sleeping environment is comfortable, quiet, dark and at a suitable temperature. You can use comfortable bedding and keep the room ventilated.\n\n3. Relax your body and mind: Doing some relaxing activities before bed, such as taking a hot bath, listening to some soft music, reading some interesting books, etc., can help relieve tension and anxiety and make it easier for you to fall asleep.\n\n4. 
Avoid drinking caffeinated beverages: Caffeine is a stimulant that can affect your sleep quality. Try to avoid drinking caffeinated beverages such as coffee, tea and cola before bed.\n\n5. Avoid doing things that are not related to sleep in bed: Doing things that are not related to sleep in bed, such as watching movies, playing games or working, etc., may interfere with your sleep.\n\n6. Try breathing techniques: Deep breathing is a relaxation technique that can help you relieve tension and anxiety, making it easier for you to fall asleep. Try breathing in slowly, holding it for a few seconds, and then exhaling slowly.\n\nIf these methods don't help you fall asleep, you may consider consulting a doctor or sleep specialist for further advice.\n```\n\n#### Load Model Locally\n\nThe above code will automatically download the model implementation and parameters via `transformers`. The complete model\nimplementation is available on [Hugging Face Hub](https://huggingface.co/THUDM/chatglm3-6b). If your network environment\nis poor, downloading model parameters might take a long time or even fail. In this case, you can first download the\nmodel to your local machine, and then load it from there.\n\nTo download the model from Hugging Face Hub, you need\nto [install Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)\nfirst, then run\n\n```shell\ngit clone https://huggingface.co/THUDM/chatglm3-6b\n```\n\nIf the download from HuggingFace is slow, you can also download it\nfrom [ModelScope](https://modelscope.cn/models/ZhipuAI/chatglm3-6b).\n\n## Model Fine-tuning\n\nWe provide a basic fine-tuning framework for ChatGLM3-6B. You can use it to fine-tune the model on your own dataset. 
For\nmore details, please refer to [Fine-tuning Demo](finetune_demo/README_en.md).\n\n### Web-based Dialogue Demo\n\n![web-demo](resources/web-demo.gif)\nYou can launch a web-based demo using Gradio with the following command:\n\n```shell\npython web_demo_gradio.py\n```\n\n![web-demo](resources/web-demo2.png)\n\nYou can launch a web-based demo using Streamlit with the following command:\n\n```shell\nstreamlit run web_demo_streamlit.py\n```\n\nThe web-based demo will run a Web Server and output an address. You can use it by opening the output address in a\nbrowser. Based on tests, the web-based demo using Streamlit runs more smoothly.\n\n### Command Line Dialogue Demo\n\n![cli-demo](resources/cli-demo.png)\n\nRun [cli_demo.py](basic_demo/cli_demo.py) in the repository:\n\n```shell\npython cli_demo.py\n```\n\nThe program interacts in the command line: enter an instruction and hit Enter to generate a response. Enter `clear` to\nclear the dialogue history, or `stop` to terminate the program.\n\n### OpenAI API / Zhipu API Demo\n\nWe have launched open source model API deployment code in OpenAI / ZhipuAI format, which can be used as the backend of\nany ChatGPT-based application.\nCurrently, you can deploy it by running [api_server.py](openai_api_demo/api_server.py) in the repository:\n\n```shell\ncd openai_api_demo\npython api_server.py\n```\n\nWe also provide sample code to test the performance of API calls.\n\n+ OpenAI test script: [openai_api_request.py](openai_api_demo/openai_api_request.py)\n+ ZhipuAI test script: [zhipu_api_request.py](openai_api_demo/zhipu_api_request.py)\n+ Test with Curl\n\n+ chat Curl test\n\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d \"{\\\"model\\\": \\\"chatglm3-6b\\\", \\\"messages\\\": [{\\\"role\\\": \\\"system\\\", \\\"content\\\": \\\"You are ChatGLM3, a large language model trained by Zhipu.AI. 
Follow the user's instructions carefully. Respond using markdown.\\\"}, {\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"你好，给我讲一个故事，大概100字\\\"}], \\\"stream\\\": false, \\\"max_tokens\\\": 100, \\\"temperature\\\": 0.8, \\\"top_p\\\": 0.8}\"\n```\n\n+ Standard OpenAI interface agent-chat Curl test\n\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d '{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"user\", \"content\": \"37乘以8加7除2等于多少？\"}], \"tools\": [{\"name\": \"track\", \"description\": \"追踪指定股票的实时价格\",\n          \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"description\": \"需要追踪的股票代码\"}},\n                         \"required\": []}},\n         {\"name\": \"Calculator\", \"description\": \"数学计算器，计算数学问题\",\n          \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"description\": \"要计算的数学公式\"}},\n                         \"required\": []}}\n         ], \"stream\": true, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}'\n```\n\n+ OpenAI-style custom interface agent-chat Curl test (you need to implement the custom tool descriptions in the\n  script openai_api_demo/tools/schema.py, and set AGENT_CONTROLLER in api_server.py to 'true'):\n\n```shell\ncurl -X POST \"http://127.0.0.1:8000/v1/chat/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d \"{\\\"model\\\": \\\"chatglm3-6b\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"37乘以8加7除2等于多少？\\\"}], \\\"stream\\\": true, \\\"max_tokens\\\": 100, \\\"temperature\\\": 0.8, \\\"top_p\\\": 0.8}\"\n```\n\nThis interface is used for autonomous scheduling of OpenAI-style custom toolboxes. 
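As a complement to the curl tests above, the same request can be composed in Python using only the standard library; a minimal sketch (it assumes the api_server.py demo is running at its default address, 127.0.0.1:8000):

```python
import json
import urllib.request

def build_chat_payload(user_message, stream=False, max_tokens=100):
    # Mirrors the JSON body used in the chat Curl test above.
    return {
        'model': 'chatglm3-6b',
        'messages': [{'role': 'user', 'content': user_message}],
        'stream': stream,
        'max_tokens': max_tokens,
        'temperature': 0.8,
        'top_p': 0.8,
    }

def post_chat(payload, base_url='http://127.0.0.1:8000'):
    # POST to the OpenAI-style chat completions endpoint served by api_server.py.
    request = urllib.request.Request(
        base_url + '/v1/chat/completions',
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
```

The bundled openai_api_request.py script covers the same ground with a full client; this sketch only illustrates the wire format.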
It has the ability to self-process\nand respond to scheduling exceptions, without the need to implement additional scheduling algorithms, and users do not\nneed an api_key.\n\n+ Testing with Python\n\n```shell\ncd openai_api_demo\npython openai_api_request.py\n```\n\nIf the test is successful, the model should return a story.\n\n### Tool Invocation\n\nFor methods of tool invocation, please refer to [Tool Invocation](tools_using_demo/README_en.md).\n\n## Low-Cost Deployment\n\n### Model Quantization\n\nBy default, the model is loaded with FP16 precision; running the above code requires about 13GB of VRAM. If your GPU's\nVRAM is limited, you can try loading the model with quantization, as follows:\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).quantize(4).cuda()\n```\n\nModel quantization will bring some performance loss. Through testing, ChatGLM3-6B can still perform natural and smooth\ngeneration under 4-bit quantization.\n\n### CPU Deployment\n\nIf you don't have GPU hardware, you can also run inference on the CPU, but the inference speed will be slower. The usage\nis as follows (requires about 32GB of memory):\n\n```python\nmodel = AutoModel.from_pretrained(\"THUDM/chatglm3-6b\", trust_remote_code=True).float()\n```\n\n### Mac Deployment\n\nFor Macs equipped with Apple Silicon or AMD GPUs, the MPS backend can be used to run ChatGLM3-6B on the GPU. Refer to\nApple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly (the correct\nversion number should be 2.x.x.dev2023xxxx, not 2.x.x).\n\nCurrently, only [loading the model locally](README_en.md#load-model-locally) is supported on macOS. Change the model\nloading in the code to load locally and use the MPS backend:\n\n```python\nmodel = AutoModel.from_pretrained(\"your local path\", trust_remote_code=True).to('mps')\n```\n\nLoading the half-precision ChatGLM3-6B model requires about 13GB of memory. 
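The CUDA, CPU, and MPS options above can be folded into a single loader that picks the best available backend; a hedged sketch (the helper names and fallback order are illustrative, not part of the official demos):

```python
def pick_device(cuda_available, mps_available):
    # Fallback order assumed here: CUDA first, then Apple MPS, then CPU.
    if cuda_available:
        return 'cuda'
    if mps_available:
        return 'mps'
    return 'cpu'

def load_chatglm(model_path='THUDM/chatglm3-6b'):
    # Imports kept local so pick_device stays importable without torch installed.
    import torch
    from transformers import AutoModel

    device = pick_device(torch.cuda.is_available(),
                         torch.backends.mps.is_available())
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    if device == 'cpu':
        return model.float().eval()  # about 32GB of RAM, as noted above
    return model.half().to(device).eval()  # about 13GB of VRAM or unified memory
```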
Machines with smaller memory (such as a 16GB\nmemory MacBook Pro) will use virtual memory on the hard disk when there is insufficient free memory, resulting in a\nsignificant slowdown in inference speed.\n\n### Multi-GPU Deployment\n\nIf you have multiple GPUs, but each GPU's VRAM size is not enough to accommodate the complete model, then the model can\nbe split across multiple GPUs. First, install accelerate: `pip install accelerate`, and then load the model as usual.\n\n### OpenVINO Demo\n\nChatGLM3-6B already supports accelerated inference with the OpenVINO toolkit, which brings a significant inference\nspeed improvement on Intel CPU and GPU devices. For\nspecific usage, please refer to [OpenVINO Demo](Intel_device_demo/openvino_demo/README.md).\n\n### TensorRT-LLM Demo\n\nChatGLM3-6B now supports accelerated inference using the TensorRT-LLM toolkit, significantly improving model inference\nspeed. For specific usage, please refer to the [TensorRT-LLM Demo](tensorrt_llm_demo/tensorrt_llm_cli_demo.py) and the\nofficial technical documentation.\n\n## Citation\n\nIf you find our work helpful, please consider citing the following papers.\n\n```\n@misc{glm2024chatglm,\n      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, \n      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},\n      year={2024},\n      eprint={2406.12793},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```"
  },
  {
    "path": "basic_demo/cli_batch_request_demo.py",
    "content": "import os\nfrom typing import Optional, Union\nfrom transformers import AutoModel, AutoTokenizer, LogitsProcessorList\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\ntokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\nmodel = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n\ndef batch(\n        model,\n        tokenizer,\n        prompts: Union[str, list[str]],\n        max_length: int = 8192,\n        num_beams: int = 1,\n        do_sample: bool = True,\n        top_p: float = 0.8,\n        temperature: float = 0.8,\n        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),\n):\n    tokenizer.encode_special_tokens = True\n    if isinstance(prompts, str):\n        prompts = [prompts]\n    batched_inputs = tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\")\n    batched_inputs = batched_inputs.to(model.device)\n\n    eos_token_id = [\n        tokenizer.eos_token_id,\n        tokenizer.get_command(\"<|user|>\"),\n        tokenizer.get_command(\"<|assistant|>\"),\n    ]\n    gen_kwargs = {\n        \"max_length\": max_length,\n        \"num_beams\": num_beams,\n        \"do_sample\": do_sample,\n        \"top_p\": top_p,\n        \"temperature\": temperature,\n        \"logits_processor\": logits_processor,\n        \"eos_token_id\": eos_token_id,\n    }\n    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)\n    batched_response = []\n    for input_ids, output_ids in zip(batched_inputs.input_ids, batched_outputs):\n        decoded_text = tokenizer.decode(output_ids[len(input_ids):])\n        batched_response.append(decoded_text.strip())\n    return batched_response\n\n\ndef main(batch_queries):\n    gen_kwargs = {\n        \"max_length\": 2048,\n        \"do_sample\": True,\n        \"top_p\": 0.8,\n        \"temperature\": 0.8,\n        
\"num_beams\": 1,\n    }\n    batch_responses = batch(model, tokenizer, batch_queries, **gen_kwargs)\n    return batch_responses\n\n\nif __name__ == \"__main__\":\n    batch_queries = [\n        \"<|user|>\\n讲个故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个爱情故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个开心故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个睡前故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个励志的故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个少壮不努力的故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个青春校园恋爱故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个工作故事\\n<|assistant|>\",\n        \"<|user|>\\n讲个旅游的故事\\n<|assistant|>\",\n    ]\n    batch_responses = main(batch_queries)\n    for response in batch_responses:\n        print(\"=\" * 10)\n        print(response)\n"
  },
  {
    "path": "basic_demo/cli_demo.py",
    "content": "import os\nimport platform\nfrom transformers import AutoTokenizer, AutoModel\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\ntokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\nmodel = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n# add .quantize(bits=4, device=\"cuda\").cuda() before .eval() to use int4 model\n# must use cuda to load int4 model\n\nos_name = platform.system()\nclear_command = 'cls' if os_name == 'Windows' else 'clear'\nstop_stream = False\n\nwelcome_prompt = \"欢迎使用 ChatGLM3-6B 模型，输入内容即可进行对话，clear 清空对话历史，stop 终止程序\"\n\n\ndef build_prompt(history):\n    prompt = welcome_prompt\n    for query, response in history:\n        prompt += f\"\\n\\n用户：{query}\"\n        prompt += f\"\\n\\nChatGLM3-6B：{response}\"\n    return prompt\n\n\ndef main():\n    past_key_values, history = None, []\n    global stop_stream\n    print(welcome_prompt)\n    while True:\n        query = input(\"\\n用户：\")\n        if query.strip() == \"stop\":\n            break\n        if query.strip() == \"clear\":\n            past_key_values, history = None, []\n            os.system(clear_command)\n            print(welcome_prompt)\n            continue\n        print(\"\\nChatGLM：\", end=\"\")\n        current_length = 0\n        for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,\n                                                                    temperature=0.01,\n                                                                    past_key_values=past_key_values,\n                                                                    return_past_key_values=True):\n            if stop_stream:\n                stop_stream = False\n                break\n            else:\n                print(response[current_length:], end=\"\", flush=True)\n                
current_length = len(response)\n        print(\"\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "basic_demo/cli_demo_bad_word_ids.py",
    "content": "\"\"\"\nThis script demonstrates how to use the `bad_words_ids` argument in the context of a conversational AI model to filter out unwanted words or phrases from the model's responses. It's designed to showcase a fundamental method of content moderation within AI-generated text, particularly useful in scenarios where maintaining the decorum of the conversation is essential.\n\nUsage:\n\n- Interact with the model by typing queries. The model will generate responses while avoiding the specified bad words.\n- Use 'clear' to clear the conversation history and 'stop' to exit the program.\n\nRequirements:\n- The script requires the Transformers library and an appropriate model checkpoint.\n\nNote: The `bad_words_ids` feature is an essential tool for controlling the output of language models, particularly in user-facing applications where content moderation is crucial.\n\"\"\"\nimport os\nimport platform\n\nfrom transformers import AutoTokenizer, AutoModel\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\ntokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\nmodel = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n\nos_name = platform.system()\nclear_command = 'cls' if os_name == 'Windows' else 'clear'\nstop_stream = False\n\nwelcome_prompt = \"欢迎使用 ChatGLM3-6B 模型，输入内容即可进行对话，clear 清空对话历史，stop 终止程序\"\n\n# probability tensor contains either `inf`, `nan` or element < 0\n\nbad_words = [\"你好\", \"ChatGLM\"]\nbad_word_ids = [tokenizer.encode(bad_word, add_special_tokens=False) for bad_word in bad_words]\n\n\ndef build_prompt(history):\n    prompt = welcome_prompt\n    for query, response in history:\n        prompt += f\"\\n\\n用户：{query}\"\n        prompt += f\"\\n\\nChatGLM3-6B：{response}\"\n    return prompt\n\ndef main():\n    past_key_values, history = None, []\n    global stop_stream\n    
print(welcome_prompt)\n    while True:\n        query = input(\"\\n用户：\")\n        if query.strip().lower() == \"stop\":\n            break\n        if query.strip().lower() == \"clear\":\n            past_key_values, history = None, []\n            os.system(clear_command)\n            print(welcome_prompt)\n            continue\n\n        # Attempt to generate a response\n        try:\n            print(\"\\nChatGLM：\", end=\"\")\n            current_length = 0\n            response_generated = False\n            for response, history, past_key_values in model.stream_chat(\n                tokenizer, query, history=history, top_p=1,\n                temperature=0.01,\n                past_key_values=past_key_values,\n                return_past_key_values=True,\n                bad_words_ids=bad_word_ids  # assuming this is implemented correctly\n            ):\n                response_generated = True\n                # Check if the response contains any bad words\n                if any(bad_word in response for bad_word in bad_words):\n                    print(\"我的回答涉嫌了 bad word\")\n                    break  # Break the loop if a bad word is detected\n\n                # Otherwise, print the generated response\n                print(response[current_length:], end=\"\", flush=True)\n                current_length = len(response)\n            if not response_generated:\n                print(\"没有生成任何回答。\")\n        except RuntimeError as e:\n            print(f\"生成文本时发生错误：{e}，这可能是涉及到设定的敏感词汇\")\n\n        print(\"\")\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "basic_demo/web_demo_gradio.py",
    "content": "\"\"\"\nThis script creates an interactive web demo for the ChatGLM3-6B model using Gradio,\na Python library for building quick and easy UI components for machine learning models.\nIt's designed to showcase the capabilities of the ChatGLM3-6B model in a user-friendly interface,\nallowing users to interact with the model through a chat-like interface.\n\nUsage:\n- Run the script to start the Gradio web server.\n- Interact with the model by typing questions and receiving responses.\n\nRequirements:\n- Gradio (required for 4.13.0 and later, 3.x is not support now) should be installed.\n\nNote: The script includes a modification to the Chatbot's postprocess method to handle markdown to HTML conversion,\nensuring that the chat interface displays formatted text correctly.\n\n\"\"\"\n\nimport os\nimport gradio as gr\nimport torch\nfrom threading import Thread\n\nfrom typing import Union, Annotated\nfrom pathlib import Path\nfrom peft import AutoPeftModelForCausalLM, PeftModelForCausalLM\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    PreTrainedModel,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n    StoppingCriteria,\n    StoppingCriteriaList,\n    TextIteratorStreamer\n)\n\nimport socket\n\nModelType = Union[PreTrainedModel, PeftModelForCausalLM]\nTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\n\n#MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nMODEL_PATH = 'chatglm3-6b'\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\n\ndef _resolve_path(path: Union[str, Path]) -> Path:\n    return Path(path).expanduser().resolve()\n\n\ndef load_model_and_tokenizer(\n        model_dir: Union[str, Path], trust_remote_code: bool = True\n) -> tuple[ModelType, TokenizerType]:\n    model_dir = _resolve_path(model_dir)\n    if (model_dir / 'adapter_config.json').exists():\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_dir, 
trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model.peft_config['default'].base_model_name_or_path\n    else:\n        model = AutoModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model_dir\n    tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_dir, trust_remote_code=trust_remote_code\n    )\n    return model, tokenizer\n\n\nmodel, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)\n\n\nclass StopOnTokens(StoppingCriteria):\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        stop_ids = [0, 2]\n        for stop_id in stop_ids:\n            if input_ids[0][-1] == stop_id:\n                return True\n        return False\n\n\ndef parse_text(text):\n    lines = text.split(\"\\n\")\n    lines = [line for line in lines if line != \"\"]\n    count = 0\n    for i, line in enumerate(lines):\n        if \"```\" in line:\n            count += 1\n            items = line.split('`')\n            if count % 2 == 1:\n                lines[i] = f'<pre><code class=\"language-{items[-1]}\">'\n            else:\n                lines[i] = f'<br></code></pre>'\n        else:\n            if i > 0:\n                if count % 2 == 1:\n                    line = line.replace(\"`\", \"\\`\")\n                    line = line.replace(\"<\", \"&lt;\")\n                    line = line.replace(\">\", \"&gt;\")\n                    line = line.replace(\" \", \"&nbsp;\")\n                    line = line.replace(\"*\", \"&ast;\")\n                    line = line.replace(\"_\", \"&lowbar;\")\n                    line = line.replace(\"-\", \"&#45;\")\n                    line = line.replace(\".\", \"&#46;\")\n                    line = line.replace(\"!\", \"&#33;\")\n                    line = line.replace(\"(\", \"&#40;\")\n                    line = 
line.replace(\")\", \"&#41;\")\n                    line = line.replace(\"$\", \"&#36;\")\n                lines[i] = \"<br>\" + line\n    text = \"\".join(lines)\n    return text\n\ndef predict(history, max_length, top_p, temperature, system_prompt):\n    stop = StopOnTokens()\n    messages = []\n    if(system_prompt!=\"\"):\n        messages.append({\"role\": \"system\", \"content\": system_prompt})\n        \n    for idx, (user_msg, model_msg) in enumerate(history):\n        if idx == len(history) - 1 and not model_msg:\n            messages.append({\"role\": \"user\", \"content\": user_msg})\n            break\n        if user_msg:\n            messages.append({\"role\": \"user\", \"content\": user_msg})\n        if model_msg:\n            messages.append({\"role\": \"assistant\", \"content\": model_msg})\n\n    print(\"\\n\\n====conversation====\\n\", messages)\n    model_inputs = tokenizer.apply_chat_template(messages,\n                                                 add_generation_prompt=True,\n                                                 tokenize=True,\n                                                 return_tensors=\"pt\").to(next(model.parameters()).device)\n    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)\n    generate_kwargs = {\n        \"input_ids\": model_inputs,\n        \"streamer\": streamer,\n        \"max_new_tokens\": max_length,\n        \"do_sample\": True,\n        \"top_p\": top_p,\n        \"temperature\": temperature,\n        \"stopping_criteria\": StoppingCriteriaList([stop]),\n        \"repetition_penalty\": 1.2,\n    }\n    t = Thread(target=model.generate, kwargs=generate_kwargs)\n    t.start()\n\n    for new_token in streamer:\n        if new_token != '':\n            history[-1][1] += new_token\n            yield history\n\n\nwith gr.Blocks(title=\"ChatGLM\") as demo:\n    gr.Markdown(\"## ChatGLM3-6B\")\n\n    with gr.Row():\n        with gr.Column(scale=4):\n            
chatbot = gr.Chatbot(layout=\"panel\")\n            with gr.Column(scale=12):\n                user_input = gr.Textbox(show_label=False, placeholder=\"Input to chat...\", lines=3, container=False)\n            with gr.Column(min_width=32, scale=1):\n                submitBtn = gr.Button(\"Submit\", variant=\"primary\")\n        with gr.Column(scale=1):\n            emptyBtn = gr.Button(\"Clear History\")\n            max_length = gr.Slider(0, 32768, value=16384, step=1.0, label=\"Maximum length\", interactive=True)\n            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label=\"Top P\", interactive=True)\n            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label=\"Temperature\", interactive=True)\n            gr.HTML(\"\"\"<span>System Prompt</span>\"\"\")\n            system_prompt = gr.Textbox(show_label=False, placeholder=\"System Prompt\", lines=6, container=False)\n\n    def user(query, history):\n        return \"\", history + [[parse_text(query), \"\"]]\n\n\n    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(\n        predict, [chatbot, max_length, top_p, temperature, system_prompt], chatbot\n    )\n    emptyBtn.click(lambda: None, None, chatbot, queue=False)\n\ndemo.queue()\ndemo.launch(server_name=socket.gethostbyname(socket.gethostname()), server_port=7870, inbrowser=True, share=False)\n"
  },
  {
    "path": "basic_demo/web_demo_streamlit.py",
    "content": "\"\"\"\nThis script is a simple web demo based on Streamlit, showcasing the use of the ChatGLM3-6B model. For a more comprehensive web demo,\nit is recommended to use 'composite_demo'.\n\nUsage:\n- Run the script using Streamlit: `streamlit run web_demo_streamlit.py`\n- Adjust the model parameters from the sidebar.\n- Enter questions in the chat input box and interact with the ChatGLM3-6B model.\n\nNote: Ensure 'streamlit' and 'transformers' libraries are installed and the required model checkpoints are available.\n\"\"\"\n\nimport os\nimport streamlit as st\nimport torch\nfrom transformers import AutoModel, AutoTokenizer\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\nst.set_page_config(\n    page_title=\"ChatGLM3-6B Streamlit Simple Demo\",\n    page_icon=\":robot:\",\n    layout=\"wide\"\n)\n\n\n@st.cache_resource\ndef get_model():\n\n    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\n    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n    return tokenizer, model\n\n\n# 加载Chatglm3的model和tokenizer\ntokenizer, model = get_model()\n\nif \"history\" not in st.session_state:\n    st.session_state.history = []\nif \"past_key_values\" not in st.session_state:\n    st.session_state.past_key_values = None\n\nmax_length = st.sidebar.slider(\"max_length\", 0, 32768, 8192, step=1)\ntop_p = st.sidebar.slider(\"top_p\", 0.0, 1.0, 0.8, step=0.01)\ntemperature = st.sidebar.slider(\"temperature\", 0.0, 1.0, 0.6, step=0.01)\n\nbuttonClean = st.sidebar.button(\"清理会话历史\", key=\"clean\")\nif buttonClean:\n    st.session_state.history = []\n    st.session_state.past_key_values = None\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    st.rerun()\n\nfor i, message in enumerate(st.session_state.history):\n    if message[\"role\"] == \"user\":\n        with 
st.chat_message(name=\"user\", avatar=\"user\"):\n            st.markdown(message[\"content\"])\n    else:\n        with st.chat_message(name=\"assistant\", avatar=\"assistant\"):\n            st.markdown(message[\"content\"])\n\nwith st.chat_message(name=\"user\", avatar=\"user\"):\n    input_placeholder = st.empty()\nwith st.chat_message(name=\"assistant\", avatar=\"assistant\"):\n    message_placeholder = st.empty()\n\nprompt_text = st.chat_input(\"请输入您的问题\")\nif prompt_text:\n    input_placeholder.markdown(prompt_text)\n    history = st.session_state.history\n    past_key_values = st.session_state.past_key_values\n    for response, history, past_key_values in model.stream_chat(\n            tokenizer,\n            prompt_text,\n            history,\n            past_key_values=past_key_values,\n            max_length=max_length,\n            top_p=top_p,\n            temperature=temperature,\n            return_past_key_values=True,\n    ):\n        message_placeholder.markdown(response)\n    st.session_state.history = history\n    st.session_state.past_key_values = past_key_values\n"
  },
  {
    "path": "composite_demo/.streamlit/config.toml",
    "content": "[theme]\nfont = \"monospace\""
  },
  {
    "path": "composite_demo/README.md",
    "content": "# ChatGLM3 Web Demo\n\n![Demo webpage](assets/demo.png)\n\n## 安装\n\n我们建议通过 [Conda](https://docs.conda.io/en/latest/) 进行环境管理。\n\n执行以下命令新建一个 conda 环境并安装所需依赖：\n\n```bash\nconda create -n chatglm3-demo python=3.10\nconda activate chatglm3-demo\npip install -r requirements.txt\n```\n\n请注意，本项目需要 Python 3.10 或更高版本。\n\n此外，使用 Code Interpreter 还需要安装 Jupyter 内核：\n\n```bash\nipython kernel install --name chatglm3-demo --user\n```\n\n## 运行\n\n运行以下命令在本地加载模型并启动 demo：\n\n```bash\nstreamlit run main.py\n```\n\n之后即可从命令行中看到 demo 的地址，点击即可访问。初次访问需要下载并加载模型，可能需要花费一定时间。\n\n如果已经在本地下载了模型，可以通过 `export MODEL_PATH=/path/to/model` 来指定从本地加载模型。如果需要自定义 Jupyter 内核，可以通过 `export IPYKERNEL=<kernel_name>` 来指定。\n\n## 使用\n\nChatGLM3 Demo 拥有三种模式：\n\n- Chat: 对话模式，在此模式下可以与模型进行对话。\n- Tool: 工具模式，模型除了对话外，还可以通过工具进行其他操作。\n- Code Interpreter: 代码解释器模式，模型可以在一个 Jupyter 环境中执行代码并获取结果，以完成复杂任务。\n\n### 对话模式\n\n对话模式下，用户可以直接在侧边栏修改 top_p, temperature, System Prompt 等参数来调整模型的行为。例如\n\n![The model responses following system prompt](assets/emojis.png)\n\n### 工具模式\n\n可以通过在 `tool_registry.py` 中注册新的工具来增强模型的能力。只需要使用 `@register_tool` 装饰函数即可完成注册。对于工具声明，函数名称即为工具的名称，函数 docstring 即为工具的说明；对于工具的参数，使用 `Annotated[typ: type, description: str, required: bool]` 标注参数的类型、描述和是否必须。\n\n例如，`get_weather` 工具的注册如下：\n\n```python\n@register_tool\ndef get_weather(\n    city_name: Annotated[str, 'The name of the city to be queried', True],\n) -> str:\n    \"\"\"\n    Get the weather for `city_name` in the following week\n    \"\"\"\n    ...\n```\n\n![The model uses tool to query the weather of pairs.](assets/tool.png)\n\n此外，你也可以在页面中通过 `Manual mode` 进入手动模式，在这一模式下你可以通过 YAML 来直接指定工具列表，但你需要手动将工具的输出反馈给模型。\n\n### 代码解释器模式\n\n由于拥有代码执行环境，此模式下的模型能够执行更为复杂的任务，例如绘制图表、执行符号运算等等。模型会根据对任务完成情况的理解自动地连续执行多个代码块，直到任务完成。因此，在这一模式下，你只需要指明希望模型执行的任务即可。\n\n例如，我们可以让 ChatGLM3 画一个爱心：\n\n![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)\n\n### 额外技巧\n\n- 在模型生成文本时，可以通过页面右上角的 `Stop` 按钮进行打断。\n- 刷新页面即可清空对话记录。\n\n# Enjoy!"
  },
  {
    "path": "composite_demo/README_en.md",
    "content": "# ChatGLM3 Web Demo\n\n![Demo webpage](assets/demo.png)\n\n## Installation\n\nWe recommend managing environments through [Conda](https://docs.conda.io/en/latest/).\n\nExecute the following commands to create a new conda environment and install the necessary dependencies:\n\n```bash\nconda create -n chatglm3-demo python=3.10\nconda activate chatglm3-demo\npip install -r requirements.txt\n```\n\nPlease note that this project requires Python 3.10 or higher.\n\nAdditionally, installing the Jupyter kernel is required for using the Code Interpreter:\n\n```bash\nipython kernel install --name chatglm3-demo --user\n```\n\n## Execution\n\nRun the following command to load the model locally and start the demo:\n\n```bash\nstreamlit run main.py\n```\n\nAfterward, the address of the demo can be seen from the command line; click to access. The first visit requires the download and loading of the model, which may take some time.\n\nIf the model has already been downloaded locally, you can specify to load the model locally through `export MODEL_PATH=/path/to/model`. If you need to customize the Jupyter kernel, you can specify it through `export IPYKERNEL=<kernel_name>`.\n\n## Usage\n\nChatGLM3 Demo has three modes:\n\n- Chat: Dialogue mode, where you can interact with the model.\n- Tool: Tool mode, where the model, in addition to dialogue, can perform other operations through tools.\n- Code Interpreter: Code interpreter mode, where the model can execute code in a Jupyter environment and obtain results to complete complex tasks.\n\n### Dialogue Mode\n\nIn dialogue mode, users can directly modify parameters such as top_p, temperature, System Prompt in the sidebar to adjust the behavior of the model. For example,\n\n![The model responses following system prompt](assets/emojis.png)\n\n### Tool Mode\n\nYou can enhance the model's capabilities by registering new tools in `tool_registry.py`. Just use the `@register_tool` decorator to complete the registration. 
For tool declarations, the function name is the name of the tool, and the function docstring is the description of the tool; for tool parameters, use `Annotated[typ: type, description: str, required: bool]` to annotate each parameter's type and description, and whether it is required.\n\nFor example, the registration of the `get_weather` tool is as follows:\n\n```python\n@register_tool\ndef get_weather(\n    city_name: Annotated[str, 'The name of the city to be queried', True],\n) -> str:\n    \"\"\"\n    Get the weather for `city_name` in the following week\n    \"\"\"\n    ...\n```\n\n![The model uses a tool to query the weather of Paris.](assets/tool.png)\n\nAdditionally, you can enter manual mode through `Manual mode` on the page. In this mode, you can specify the tool list directly in YAML, but you need to manually feed the tool's output back to the model.\n\n### Code Interpreter Mode\n\nBecause it has access to a code execution environment, the model in this mode can perform more complex tasks, such as drawing charts and performing symbolic computation. The model will automatically execute multiple code blocks in succession, based on its understanding of the task's completion status, until the task is completed. Therefore, in this mode, you only need to specify the task you want the model to perform.\n\nFor example, we can ask ChatGLM3 to draw a heart:\n\n![The code interpreter draws a heart according to the user's instructions.](assets/heart.png)\n\n### Additional Tips\n\n- While the model is generating text, it can be interrupted with the `Stop` button at the top right corner of the page.\n- Refreshing the page will clear the dialogue history.\n\n# Enjoy!"
  },
  {
    "path": "composite_demo/client.py",
    "content": "from __future__ import annotations\n\nimport os\nimport streamlit as st\nimport torch\n\nfrom collections.abc import Iterable\nfrom typing import Any, Protocol\nfrom huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token\nfrom transformers import AutoModel, AutoTokenizer, AutoConfig\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom transformers.generation.utils import LogitsProcessorList\n\nfrom conversation import Conversation\n\nTOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nPT_PATH = os.environ.get('PT_PATH', None)\nPRE_SEQ_LEN = int(os.environ.get(\"PRE_SEQ_LEN\", 128))\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\n\n@st.cache_resource\ndef get_client() -> Client:\n    client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)\n    return client\n\n\nclass Client(Protocol):\n    def generate_stream(self,\n                        system: str | None,\n                        tools: list[dict] | None,\n                        history: list[Conversation],\n                        **parameters: Any\n                        ) -> Iterable[TextGenerationStreamResponse]:\n        ...\n\n\ndef stream_chat(\n        self, tokenizer, query: str,\n        history: list[tuple[str, str]] = None,\n        role: str = \"user\",\n        past_key_values=None,\n        max_new_tokens: int = 256,\n        do_sample=True, top_p=0.8,\n        temperature=0.8,\n        repetition_penalty=1.0,\n        length_penalty=1.0, num_beams=1,\n        logits_processor=None,\n        return_past_key_values=False,\n        **kwargs\n):\n    class InvalidScoreLogitsProcessor(LogitsProcessor):\n        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n            if torch.isnan(scores).any() or torch.isinf(scores).any():\n                
scores.zero_()\n                scores[..., 5] = 5e4\n            return scores\n\n    if history is None:\n        history = []\n\n    print(\"\\n== Input ==\\n\", query)\n    print(\"\\n==History==\\n\", history)\n\n    if logits_processor is None:\n        logits_processor = LogitsProcessorList()\n    logits_processor.append(InvalidScoreLogitsProcessor())\n    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command(\"<|user|>\"),\n                    tokenizer.get_command(\"<|observation|>\")]\n    gen_kwargs = {\"max_new_tokens\": max_new_tokens,\n                  \"do_sample\": do_sample,\n                  \"top_p\": top_p,\n                  \"temperature\": temperature,\n                  \"logits_processor\": logits_processor,\n                  \"repetition_penalty\": repetition_penalty,\n                  \"length_penalty\": length_penalty,\n                  \"num_beams\": num_beams,\n                  **kwargs\n                  }\n\n    if past_key_values is None:\n        inputs = tokenizer.build_chat_input(query, history=history, role=role)\n    else:\n        inputs = tokenizer.build_chat_input(query, role=role)\n    inputs = inputs.to(self.device)\n    if past_key_values is not None:\n        past_length = past_key_values[0][0].shape[0]\n        if self.transformer.pre_seq_len is not None:\n            past_length -= self.transformer.pre_seq_len\n        inputs.position_ids += past_length\n        attention_mask = inputs.attention_mask\n        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)\n        inputs['attention_mask'] = attention_mask\n    history.append({\"role\": role, \"content\": query})\n    input_sequence_length = inputs['input_ids'].shape[1]\n    if input_sequence_length + max_new_tokens >= self.config.seq_length:\n        yield \"Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. 
You may adjust the generation parameter to enable longer chat history.\".format(\n            input_sequence_length, max_new_tokens, self.config.seq_length\n        ), history\n        return\n\n    if input_sequence_length > self.config.seq_length:\n        yield \"Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.\".format(\n            input_sequence_length, self.config.seq_length\n        ), history\n        return\n\n    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,\n                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,\n                                        **gen_kwargs):\n        if return_past_key_values:\n            outputs, past_key_values = outputs\n        outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):]\n        response = tokenizer.decode(outputs)\n        if response and response[-1] != \"�\":\n            new_history = history\n            if return_past_key_values:\n                yield response, new_history, past_key_values\n            else:\n                yield response, new_history\n\n\nclass HFClient(Client):\n    def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str = None):\n        self.model_path = model_path\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n\n        if pt_checkpoint is not None and os.path.exists(pt_checkpoint):\n            config = AutoConfig.from_pretrained(\n                model_path,\n                trust_remote_code=True,\n                pre_seq_len=PRE_SEQ_LEN\n            )\n            self.model = AutoModel.from_pretrained(\n                model_path,\n                trust_remote_code=True,\n                config=config,\n                device_map=\"auto\").eval()\n            # add .quantize(bits=4, device=\"cuda\").cuda() before .eval() and remove device_map=\"auto\" to use 
int4 model\n            # must use cuda to load int4 model\n            prefix_state_dict = torch.load(os.path.join(pt_checkpoint, \"pytorch_model.bin\"))\n            new_prefix_state_dict = {}\n            for k, v in prefix_state_dict.items():\n                if k.startswith(\"transformer.prefix_encoder.\"):\n                    new_prefix_state_dict[k[len(\"transformer.prefix_encoder.\"):]] = v\n            print(\"Loaded from pt checkpoints\", new_prefix_state_dict.keys())\n            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)\n        else:\n            self.model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n            # add .quantize(bits=4, device=\"cuda\").cuda() before .eval() and remove device_map=\"auto\" to use int4 model\n            # must use cuda to load int4 model\n\n    def generate_stream(\n            self,\n            system: str | None,\n            tools: list[dict] | None,\n            history: list[Conversation],\n            **parameters: Any\n    ) -> Iterable[TextGenerationStreamResponse]:\n        chat_history = [{\n            'role': 'system',\n            'content': system if not tools else TOOL_PROMPT,\n        }]\n\n        if tools:\n            chat_history[0]['tools'] = tools\n\n        for conversation in history[:-1]:\n            chat_history.append({\n                'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),\n                'content': conversation.content,\n            })\n\n        query = history[-1].content\n        role = str(history[-1].role).removeprefix('<|').removesuffix('|>')\n        text = ''\n        for new_text, _ in stream_chat(\n                self.model,\n                self.tokenizer,\n                query,\n                chat_history,\n                role,\n                **parameters,\n        ):\n            word = new_text.removeprefix(text)\n            word_stripped = 
word.strip()\n            text = new_text\n            yield TextGenerationStreamResponse(\n                generated_text=text,\n                token=Token(\n                    id=0,\n                    logprob=0,\n                    text=word,\n                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),\n                )\n            )\n"
  },
  {
    "path": "composite_demo/conversation.py",
    "content": "from dataclasses import dataclass\nfrom enum import auto, Enum\nimport json\n\nfrom PIL.Image import Image\nimport streamlit as st\nfrom streamlit.delta_generator import DeltaGenerator\n\nTOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\\n'\n\nclass Role(Enum):\n    SYSTEM = auto()\n    USER = auto()\n    ASSISTANT = auto()\n    TOOL = auto()\n    INTERPRETER = auto()\n    OBSERVATION = auto()\n\n    def __str__(self):\n        match self:\n            case Role.SYSTEM:\n                return \"<|system|>\"\n            case Role.USER:\n                return \"<|user|>\"\n            case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:\n                return \"<|assistant|>\"\n            case Role.OBSERVATION:\n                return \"<|observation|>\"\n            \n    # Get the message block for the given role\n    def get_message(self):\n        # Compare by value here, because the enum object in the session state\n        # is not the same as the enum cases here, due to streamlit's rerunning\n        # behavior.\n        match self.value:\n            case Role.SYSTEM.value:\n                return\n            case Role.USER.value:\n                return st.chat_message(name=\"user\", avatar=\"user\")\n            case Role.ASSISTANT.value:\n                return st.chat_message(name=\"assistant\", avatar=\"assistant\")\n            case Role.TOOL.value:\n                return st.chat_message(name=\"tool\", avatar=\"assistant\")\n            case Role.INTERPRETER.value:\n                return st.chat_message(name=\"interpreter\", avatar=\"assistant\")\n            case Role.OBSERVATION.value:\n                return st.chat_message(name=\"observation\", avatar=\"user\")\n            case _:\n                st.error(f'Unexpected role: {self}')\n\n@dataclass\nclass Conversation:\n    role: Role\n    content: str\n    tool: str | None = None\n    image: Image | None = None\n\n    
def __str__(self) -> str:\n        print(self.role, self.content, self.tool)\n        match self.role:\n            case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:\n                return f'{self.role}\\n{self.content}'\n            case Role.TOOL:\n                return f'{self.role}{self.tool}\\n{self.content}'\n            case Role.INTERPRETER:\n                return f'{self.role}interpreter\\n{self.content}'\n\n    # Human readable format\n    def get_text(self) -> str:\n        text = postprocess_text(self.content)\n        match self.role.value:\n            case Role.TOOL.value:\n                text = f'Calling tool `{self.tool}`:\\n\\n{text}'\n            case Role.INTERPRETER.value:\n                text = f'{text}'\n            case Role.OBSERVATION.value:\n                text = f'Observation:\\n```\\n{text}\\n```'\n        return text\n    \n    # Display as a markdown block\n    def show(self, placeholder: DeltaGenerator | None=None) -> str:\n        if placeholder:\n            message = placeholder\n        else:\n            message = self.role.get_message()\n        if self.image:\n            message.image(self.image)\n        else:\n            text = self.get_text()\n            message.markdown(text)\n\ndef preprocess_text(\n    system: str | None,\n    tools: list[dict] | None,\n    history: list[Conversation],\n) -> str:\n    if tools:\n        tools = json.dumps(tools, indent=4, ensure_ascii=False)\n\n    prompt = f\"{Role.SYSTEM}\\n\"\n    prompt += system if not tools else TOOL_PROMPT\n    if tools:\n        tools = json.loads(tools)\n        prompt += json.dumps(tools, ensure_ascii=False)\n    for conversation in history:\n        prompt += f'{conversation}'\n    prompt += f'{Role.ASSISTANT}\\n'\n    return prompt\n\ndef postprocess_text(text: str) -> str:\n    text = text.replace(\"\\(\", \"$\")\n    text = text.replace(\"\\)\", \"$\")\n    text = text.replace(\"\\[\", \"$$\")\n    text = text.replace(\"\\]\", 
\"$$\")\n    text = text.replace(\"<|assistant|>\", \"\")\n    text = text.replace(\"<|observation|>\", \"\")\n    text = text.replace(\"<|system|>\", \"\")\n    text = text.replace(\"<|user|>\", \"\")\n    return text.strip()"
  },
  {
    "path": "composite_demo/demo_chat.py",
    "content": "import streamlit as st\nfrom streamlit.delta_generator import DeltaGenerator\n\nfrom client import get_client\nfrom conversation import postprocess_text, preprocess_text, Conversation, Role\n\nclient = get_client()\n\n\n# Append a conversation into history, while show it in a new markdown block\ndef append_conversation(\n        conversation: Conversation,\n        history: list[Conversation],\n        placeholder: DeltaGenerator | None = None,\n) -> None:\n    history.append(conversation)\n    conversation.show(placeholder)\n\n\ndef main(\n        prompt_text: str,\n        system_prompt: str,\n        top_p: float = 0.8,\n        temperature: float = 0.95,\n        repetition_penalty: float = 1.0,\n        max_new_tokens: int = 1024,\n        retry: bool = False\n):\n    placeholder = st.empty()\n    with placeholder.container():\n        if 'chat_history' not in st.session_state:\n            st.session_state.chat_history = []\n\n    if prompt_text == \"\" and retry == False:\n        print(\"\\n== Clean ==\\n\")\n        st.session_state.chat_history = []\n        return\n\n    history: list[Conversation] = st.session_state.chat_history\n    for conversation in history:\n        conversation.show()\n\n    if retry:\n        print(\"\\n== Retry ==\\n\")\n        last_user_conversation_idx = None\n        for idx, conversation in enumerate(history):\n            if conversation.role == Role.USER:\n                last_user_conversation_idx = idx\n        if last_user_conversation_idx is not None:\n            prompt_text = history[last_user_conversation_idx].content\n            del history[last_user_conversation_idx:]\n\n\n    if prompt_text:\n        prompt_text = prompt_text.strip()\n        append_conversation(Conversation(Role.USER, prompt_text), history)\n        placeholder = st.empty()\n        message_placeholder = placeholder.chat_message(name=\"assistant\", avatar=\"assistant\")\n        markdown_placeholder = 
message_placeholder.empty()\n\n        output_text = ''\n        for response in client.generate_stream(\n                system_prompt,\n                tools=None,\n                history=history,\n                do_sample=True,\n                max_new_tokens=max_new_tokens,\n                temperature=temperature,\n                top_p=top_p,\n                stop_sequences=[str(Role.USER)],\n                repetition_penalty=repetition_penalty,\n        ):\n            token = response.token\n            if response.token.special:\n                print(\"\\n==Output:==\\n\", output_text)\n                match token.text.strip():\n                    case '<|user|>':\n                        break\n                    case _:\n                        st.error(f'Unexpected special token: {token.text.strip()}')\n                        break\n            output_text += response.token.text\n            markdown_placeholder.markdown(postprocess_text(output_text + '▌'))\n\n        append_conversation(Conversation(\n            Role.ASSISTANT,\n            postprocess_text(output_text),\n        ), history, markdown_placeholder)"
  },
  {
    "path": "composite_demo/demo_ci.py",
    "content": "import base64\nfrom io import BytesIO\nimport os\nfrom pprint import pprint\nimport queue\nimport re\nfrom subprocess import PIPE\n\nimport jupyter_client\nfrom PIL import Image\nimport streamlit as st\nfrom streamlit.delta_generator import DeltaGenerator\n\nfrom client import get_client\nfrom conversation import postprocess_text, preprocess_text, Conversation, Role\n\nIPYKERNEL = os.environ.get('IPYKERNEL', 'chatglm3-demo')\n\nSYSTEM_PROMPT = '你是一位智能AI助手，你叫ChatGLM，你连接着一台电脑，但请注意不能联网。在使用Python解决任务时，你可以运行代码并得到结果，如果运行结果有错误，你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件，文件默认存储路径是/mnt/data/。'\n\nclient = get_client()\n\n\nclass CodeKernel(object):\n    def __init__(self,\n                 kernel_name='kernel',\n                 kernel_id=None,\n                 kernel_config_path=\"\",\n                 python_path=None,\n                 ipython_path=None,\n                 init_file_path=\"./startup.py\",\n                 verbose=1):\n\n        self.kernel_name = kernel_name\n        self.kernel_id = kernel_id\n        self.kernel_config_path = kernel_config_path\n        self.python_path = python_path\n        self.ipython_path = ipython_path\n        self.init_file_path = init_file_path\n        self.verbose = verbose\n\n        if python_path is None and ipython_path is None:\n            env = None\n        else:\n            env = {\"PATH\": self.python_path + \":$PATH\", \"PYTHONPATH\": self.python_path}\n\n        # Initialize the backend kernel\n        self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,\n                                                           connection_file=self.kernel_config_path,\n                                                           exec_files=[self.init_file_path],\n                                                           env=env)\n        if self.kernel_config_path:\n            self.kernel_manager.load_connection_file()\n            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)\n            
print(\"Backend kernel started with the configuration: {}\".format(\n                self.kernel_config_path))\n        else:\n            self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)\n            print(\"Backend kernel started with the configuration: {}\".format(\n                self.kernel_manager.connection_file))\n\n        if verbose:\n            pprint(self.kernel_manager.get_connection_info())\n\n        # Initialize the code kernel\n        self.kernel = self.kernel_manager.blocking_client()\n        # self.kernel.load_connection_file()\n        self.kernel.start_channels()\n        print(\"Code kernel started.\")\n\n    def execute(self, code):\n        self.kernel.execute(code)\n        try:\n            shell_msg = self.kernel.get_shell_msg(timeout=30)\n            io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']\n            while True:\n                msg_out = io_msg_content\n                ### Poll the message\n                try:\n                    io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']\n                    if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':\n                        break\n                except queue.Empty:\n                    break\n\n            return shell_msg, msg_out\n        except Exception as e:\n            print(e)\n            return None\n\n    def execute_interactive(self, code, verbose=False):\n        shell_msg = self.kernel.execute_interactive(code)\n        if shell_msg is queue.Empty:\n            if verbose:\n                print(\"Timeout waiting for shell message.\")\n        self.check_msg(shell_msg, verbose=verbose)\n\n        return shell_msg\n\n    def inspect(self, code, verbose=False):\n        msg_id = self.kernel.inspect(code)\n        shell_msg = self.kernel.get_shell_msg(timeout=30)\n        if shell_msg is queue.Empty:\n            if verbose:\n                print(\"Timeout waiting for shell 
message.\")\n        self.check_msg(shell_msg, verbose=verbose)\n\n        return shell_msg\n\n    def get_error_msg(self, msg, verbose=False) -> str | None:\n        if msg['content']['status'] == 'error':\n            try:\n                error_msg = msg['content']['traceback']\n            except:\n                try:\n                    error_msg = msg['content']['traceback'][-1].strip()\n                except:\n                    error_msg = \"Traceback Error\"\n            if verbose:\n                print(\"Error: \", error_msg)\n            return error_msg\n        return None\n\n    def check_msg(self, msg, verbose=False):\n        status = msg['content']['status']\n        if status == 'ok':\n            if verbose:\n                print(\"Execution succeeded.\")\n        elif status == 'error':\n            for line in msg['content']['traceback']:\n                if verbose:\n                    print(line)\n\n    def shutdown(self):\n        # Shutdown the backend kernel\n        self.kernel_manager.shutdown_kernel()\n        print(\"Backend kernel shutdown.\")\n        # Shutdown the code kernel\n        self.kernel.shutdown()\n        print(\"Code kernel shutdown.\")\n\n    def restart(self):\n        # Restart the backend kernel\n        self.kernel_manager.restart_kernel()\n        # print(\"Backend kernel restarted.\")\n\n    def interrupt(self):\n        # Interrupt the backend kernel\n        self.kernel_manager.interrupt_kernel()\n        # print(\"Backend kernel interrupted.\")\n\n    def is_alive(self):\n        return self.kernel.is_alive()\n\n\ndef b64_2_img(data):\n    buff = BytesIO(base64.b64decode(data))\n    return Image.open(buff)\n\n\ndef clean_ansi_codes(input_string):\n    ansi_escape = re.compile(r'(\\x9B|\\x1B\\[|\\u001b\\[)[0-?]*[ -/]*[@-~]')\n    return ansi_escape.sub('', input_string)\n\n\ndef execute(code, kernel: CodeKernel) -> tuple[str, str | Image.Image]:\n    res = \"\"\n    res_type = None\n    code = 
code.replace(\"<|observation|>\", \"\")\n    code = code.replace(\"<|assistant|>interpreter\", \"\")\n    code = code.replace(\"<|assistant|>\", \"\")\n    code = code.replace(\"<|user|>\", \"\")\n    code = code.replace(\"<|system|>\", \"\")\n    msg, output = kernel.execute(code)\n\n    if msg['metadata']['status'] == \"timeout\":\n        return res_type, 'Timed out'\n    elif msg['metadata']['status'] == 'error':\n        return res_type, clean_ansi_codes('\\n'.join(kernel.get_error_msg(msg, verbose=True)))\n\n    if 'text' in output:\n        res_type = \"text\"\n        res = output['text']\n    elif 'data' in output:\n        for key in output['data']:\n            if 'text/plain' in key:\n                res_type = \"text\"\n                res = output['data'][key]\n            elif 'image/png' in key:\n                res_type = \"image\"\n                res = output['data'][key]\n                break\n\n    if res_type == \"image\":\n        return res_type, b64_2_img(res)\n    elif res_type == \"text\" or res_type == \"traceback\":\n        res = res\n\n    return res_type, res\n\n\n@st.cache_resource\ndef get_kernel():\n    kernel = CodeKernel()\n    return kernel\n\n\ndef extract_code(text: str) -> str:\n    pattern = r'```([^\\n]*)\\n(.*?)```'\n    matches = re.findall(pattern, text, re.DOTALL)\n    return matches[-1][1]\n\n\n# Append a conversation into history, while show it in a new markdown block\ndef append_conversation(\n        conversation: Conversation,\n        history: list[Conversation],\n        placeholder: DeltaGenerator | None = None,\n) -> None:\n    history.append(conversation)\n    conversation.show(placeholder)\n\n\ndef main(\n        prompt_text: str,\n        top_p: float = 0.2,\n        temperature: float = 0.1,\n        repetition_penalty: float = 1.1,\n        max_new_tokens: int = 1024,\n        truncate_length: int = 1024,\n        retry: bool = False\n):\n    if 'ci_history' not in st.session_state:\n        
st.session_state.ci_history = []\n\n    if prompt_text == \"\" and not retry:\n        print(\"\\n== Clean ==\\n\")\n        st.session_state.ci_history = []\n        return\n\n    history: list[Conversation] = st.session_state.ci_history\n    for conversation in history:\n        conversation.show()\n\n    if retry:\n        print(\"\\n== Retry ==\\n\")\n        last_user_conversation_idx = None\n        for idx, conversation in enumerate(history):\n            if conversation.role == Role.USER:\n                last_user_conversation_idx = idx\n        if last_user_conversation_idx is not None:\n            prompt_text = history[last_user_conversation_idx].content\n            del history[last_user_conversation_idx:]\n\n    if prompt_text:\n        prompt_text = prompt_text.strip()\n        role = Role.USER\n        append_conversation(Conversation(role, prompt_text), history)\n\n        placeholder = st.container()\n        message_placeholder = placeholder.chat_message(name=\"assistant\", avatar=\"assistant\")\n        markdown_placeholder = message_placeholder.empty()\n\n        for _ in range(5):\n            output_text = ''\n            for response in client.generate_stream(\n                    system=SYSTEM_PROMPT,\n                    tools=None,\n                    history=history,\n                    do_sample=True,\n                    max_new_tokens=max_new_tokens,\n                    temperature=temperature,\n                    top_p=top_p,\n                    stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],\n                    repetition_penalty=repetition_penalty,\n            ):\n                token = response.token\n                if response.token.special:\n                    print(\"\\n==Output:==\\n\", output_text)\n                    match token.text.strip():\n                        case '<|user|>':\n                            append_conversation(Conversation(\n                                
Role.ASSISTANT,\n                                postprocess_text(output_text),\n                            ), history, markdown_placeholder)\n                            return\n                        # Initiate tool call\n                        case '<|assistant|>':\n                            append_conversation(Conversation(\n                                Role.ASSISTANT,\n                                postprocess_text(output_text),\n                            ), history, markdown_placeholder)\n                            message_placeholder = placeholder.chat_message(name=\"interpreter\", avatar=\"assistant\")\n                            markdown_placeholder = message_placeholder.empty()\n                            output_text = ''\n                            continue\n                        case '<|observation|>':\n                            code = extract_code(output_text)\n\n                            display_text = output_text.split('interpreter')[-1].strip()\n                            append_conversation(Conversation(\n                                Role.INTERPRETER,\n                                postprocess_text(display_text),\n                            ), history, markdown_placeholder)\n                            message_placeholder = placeholder.chat_message(name=\"observation\", avatar=\"user\")\n                            markdown_placeholder = message_placeholder.empty()\n                            output_text = ''\n\n                            with markdown_placeholder:\n                                with st.spinner('Executing code...'):\n                                    try:\n                                        res_type, res = execute(code, get_kernel())\n                                    except Exception as e:\n                                        st.error(f'Error when executing code: {e}')\n                                        return\n                            print(\"Received:\", res_type, res)\n     
                       if truncate_length:\n                                if res_type == 'text' and len(res) > truncate_length:\n                                    res = res[:truncate_length] + ' [TRUNCATED]'\n\n                            append_conversation(Conversation(\n                                Role.OBSERVATION,\n                                '[Image]' if res_type == 'image' else postprocess_text(res),\n                                tool=None,\n                                image=res if res_type == 'image' else None,\n                            ), history, markdown_placeholder)\n                            message_placeholder = placeholder.chat_message(name=\"assistant\", avatar=\"assistant\")\n                            markdown_placeholder = message_placeholder.empty()\n                            output_text = ''\n                            break\n                        case _:\n                            st.error(f'Unexpected special token: {token.text.strip()}')\n                            break\n                output_text += response.token.text\n                display_text = output_text.split('interpreter')[-1].strip()\n                markdown_placeholder.markdown(postprocess_text(display_text + '▌'))\n            else:\n                append_conversation(Conversation(\n                    Role.ASSISTANT,\n                    postprocess_text(output_text),\n                ), history, markdown_placeholder)\n                return\n    else:\n        st.session_state.ci_history = []\n"
  },
  {
    "path": "composite_demo/demo_tool.py",
    "content": "import re\nimport yaml\nfrom yaml import YAMLError\n\nimport streamlit as st\nfrom streamlit.delta_generator import DeltaGenerator\n\nfrom client import get_client\nfrom conversation import postprocess_text, preprocess_text, Conversation, Role\nfrom tool_registry import dispatch_tool, get_tools\n\nEXAMPLE_TOOL = {\n    \"name\": \"get_current_weather\",\n    \"description\": \"Get the current weather in a given location\",\n    \"parameters\": {\n        \"type\": \"object\",\n        \"properties\": {\n            \"location\": {\n                \"type\": \"string\",\n                \"description\": \"The city and state, e.g. San Francisco, CA\",\n            },\n            \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n        },\n        \"required\": [\"location\"],\n    }\n}\n\nclient = get_client()\n\n\ndef tool_call(*args, **kwargs) -> dict:\n    print(\"=== Tool call===\")\n    print(args)\n    print(kwargs)\n    st.session_state.calling_tool = True\n    return kwargs\n\n\ndef yaml_to_dict(tools: str) -> list[dict] | None:\n    try:\n        return yaml.safe_load(tools)\n    except YAMLError:\n        return None\n\n\ndef extract_code(text: str) -> str:\n    pattern = r'```([^\\n]*)\\n(.*?)```'\n    matches = re.findall(pattern, text, re.DOTALL)\n    print(matches)\n    return matches[-1][1]\n\n\n# Append a conversation into history, while show it in a new markdown block\ndef append_conversation(\n        conversation: Conversation,\n        history: list[Conversation],\n        placeholder: DeltaGenerator | None = None,\n) -> None:\n    history.append(conversation)\n    conversation.show(placeholder)\n\n\ndef main(\n        prompt_text: str,\n        top_p: float = 0.2,\n        temperature: float = 0.1,\n        repetition_penalty: float = 1.1,\n        max_new_tokens: int = 1024,\n        truncate_length: int = 1024,\n        retry: bool = False\n):\n    manual_mode = st.toggle('Manual mode',\n              
              help='Define your tools in YAML format. You need to supply tool call results manually.'\n                            )\n\n    if manual_mode:\n        with st.expander('Tools'):\n            tools = st.text_area(\n                'Define your tools in YAML format here:',\n                yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),\n                height=400,\n            )\n        tools = yaml_to_dict(tools)\n\n        if not tools:\n            st.error('YAML format error in tools definition')\n    else:\n        tools = get_tools()\n\n    if 'tool_history' not in st.session_state:\n        st.session_state.tool_history = []\n    if 'calling_tool' not in st.session_state:\n        st.session_state.calling_tool = False\n\n    if 'chat_history' not in st.session_state:\n        st.session_state.chat_history = []\n\n    if prompt_text == \"\" and not retry:\n        print(\"\\n== Clean ==\\n\")\n        st.session_state.chat_history = []\n        return\n\n    history: list[Conversation] = st.session_state.chat_history\n    for conversation in history:\n        conversation.show()\n\n    if retry:\n        print(\"\\n== Retry ==\\n\")\n        last_user_conversation_idx = None\n        for idx, conversation in enumerate(history):\n            if conversation.role == Role.USER:\n                last_user_conversation_idx = idx\n        if last_user_conversation_idx is not None:\n            prompt_text = history[last_user_conversation_idx].content\n            del history[last_user_conversation_idx:]\n\n    if prompt_text:\n        prompt_text = prompt_text.strip()\n        role = Role.OBSERVATION if st.session_state.calling_tool else Role.USER\n        append_conversation(Conversation(role, prompt_text), history)\n        st.session_state.calling_tool = False\n\n        placeholder = st.container()\n        message_placeholder = placeholder.chat_message(name=\"assistant\", avatar=\"assistant\")\n        markdown_placeholder = 
message_placeholder.empty()\n\n        for _ in range(5):\n            output_text = ''\n            for response in client.generate_stream(\n                    system=None,\n                    tools=tools,\n                    history=history,\n                    do_sample=True,\n                    max_new_tokens=max_new_tokens,\n                    temperature=temperature,\n                    top_p=top_p,\n                    stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],\n                    repetition_penalty=repetition_penalty,\n            ):\n                token = response.token\n                if response.token.special:\n                    print(\"\\n==Output:==\\n\", output_text)\n                    match token.text.strip():\n                        case '<|user|>':\n                            append_conversation(Conversation(\n                                Role.ASSISTANT,\n                                postprocess_text(output_text),\n                            ), history, markdown_placeholder)\n                            return\n                        # Initiate tool call\n                        case '<|assistant|>':\n                            append_conversation(Conversation(\n                                Role.ASSISTANT,\n                                postprocess_text(output_text),\n                            ), history, markdown_placeholder)\n                            output_text = ''\n                            message_placeholder = placeholder.chat_message(name=\"tool\", avatar=\"assistant\")\n                            markdown_placeholder = message_placeholder.empty()\n                            continue\n                        case '<|observation|>':\n                            tool, *call_args_text = output_text.strip().split('\\n')\n                            call_args_text = '\\n'.join(call_args_text)\n\n                            append_conversation(Conversation(\n                             
   Role.TOOL,\n                                postprocess_text(output_text),\n                                tool,\n                            ), history, markdown_placeholder)\n                            message_placeholder = placeholder.chat_message(name=\"observation\", avatar=\"user\")\n                            markdown_placeholder = message_placeholder.empty()\n\n                            try:\n                                code = extract_code(call_args_text)\n                                args = eval(code, {'tool_call': tool_call}, {})\n                            except:\n                                st.error('Failed to parse tool call')\n                                return\n\n                            output_text = ''\n\n                            if manual_mode:\n                                st.info('Please provide tool call results below:')\n                                return\n                            else:\n                                with markdown_placeholder:\n                                    with st.spinner(f'Calling tool {tool}...'):\n                                        observation = dispatch_tool(tool, args)\n\n                                if len(observation) > truncate_length:\n                                    observation = observation[:truncate_length] + ' [TRUNCATED]'\n                                append_conversation(Conversation(\n                                    Role.OBSERVATION, observation\n                                ), history, markdown_placeholder)\n                                message_placeholder = placeholder.chat_message(name=\"assistant\", avatar=\"assistant\")\n                                markdown_placeholder = message_placeholder.empty()\n                                st.session_state.calling_tool = False\n                                break\n                        case _:\n                            st.error(f'Unexpected special token: {token.text.strip()}')\n   
                         return\n                output_text += response.token.text\n                markdown_placeholder.markdown(postprocess_text(output_text + '▌'))\n            else:\n                append_conversation(Conversation(\n                    Role.ASSISTANT,\n                    postprocess_text(output_text),\n                ), history, markdown_placeholder)\n                return\n"
  },
  {
    "path": "composite_demo/main.py",
    "content": "import streamlit as st\nst.set_page_config(\n    page_title=\"ChatGLM3 Demo\",\n    page_icon=\":robot:\",\n    layout='centered',\n    initial_sidebar_state='expanded',\n)\n\n\nimport demo_chat, demo_ci, demo_tool\nfrom enum import Enum\n\nDEFAULT_SYSTEM_PROMPT = '''\nYou are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\n'''.strip()\n\n# Set the title of the demo\nst.title(\"ChatGLM3 Demo\")\n\n# Add your custom text here, with smaller font size\nst.markdown(\n    \"<sub>智谱AI 公开在线技术文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof </sub> \\n\\n <sub> 更多 ChatGLM3-6B 的使用方法请参考文档。</sub>\",\n    unsafe_allow_html=True)\n\n\nclass Mode(str, Enum):\n    CHAT, TOOL, CI = '💬 Chat', '🛠️ Tool', '🧑‍💻 Code Interpreter'\n\n\nwith st.sidebar:\n    top_p = st.slider(\n        'top_p', 0.0, 1.0, 0.8, step=0.01\n    )\n    temperature = st.slider(\n        'temperature', 0.0, 1.5, 0.95, step=0.01\n    )\n    repetition_penalty = st.slider(\n        'repetition_penalty', 0.0, 2.0, 1.1, step=0.01\n    )\n    max_new_token = st.slider(\n        'Output length', 5, 32000, 256, step=1\n    )\n\n    cols = st.columns(2)\n    export_btn = cols[0]\n    clear_history = cols[1].button(\"Clear History\", use_container_width=True)\n    retry = export_btn.button(\"Retry\", use_container_width=True)\n\n    system_prompt = st.text_area(\n        label=\"System Prompt (Only for chat mode)\",\n        height=300,\n        value=DEFAULT_SYSTEM_PROMPT,\n    )\n\nprompt_text = st.chat_input(\n    'Chat with ChatGLM3!',\n    key='chat_input',\n)\n\ntab = st.radio(\n    'Mode',\n    [mode.value for mode in Mode],\n    horizontal=True,\n    label_visibility='hidden',\n)\n\nif clear_history or retry:\n    prompt_text = \"\"\n\nmatch tab:\n    case Mode.CHAT:\n        demo_chat.main(\n            retry=retry,\n            top_p=top_p,\n            temperature=temperature,\n            
prompt_text=prompt_text,\n            system_prompt=system_prompt,\n            repetition_penalty=repetition_penalty,\n            max_new_tokens=max_new_token\n        )\n    case Mode.TOOL:\n        demo_tool.main(\n            retry=retry,\n            top_p=top_p,\n            temperature=temperature,\n            prompt_text=prompt_text,\n            repetition_penalty=repetition_penalty,\n            max_new_tokens=max_new_token,\n            truncate_length=1024)\n    case Mode.CI:\n        demo_ci.main(\n            retry=retry,\n            top_p=top_p,\n            temperature=temperature,\n            prompt_text=prompt_text,\n            repetition_penalty=repetition_penalty,\n            max_new_tokens=max_new_token,\n            truncate_length=1024)\n    case _:\n        st.error(f'Unexpected tab: {tab}')\n"
  },
  {
    "path": "composite_demo/requirements.txt",
    "content": "huggingface_hub>=0.19.4\npillow>=10.1.0\npyyaml>=6.0.1\nrequests>=2.31.0\nipykernel>=6.26.0\nipython>=8.18.1\njupyter_client>=8.6.0\n"
  },
  {
    "path": "composite_demo/tool_registry.py",
    "content": "\"\"\"\nThis code is the tool registration part. By registering the tool, the model can call the tool.\nThis code provides extended functionality to the model, enabling it to call and interact with a variety of utilities\nthrough defined interfaces.\n\"\"\"\n\nimport copy\nimport inspect\nfrom pprint import pformat\nimport traceback\nfrom types import GenericAlias\nfrom typing import get_origin, Annotated\nimport subprocess\n\n_TOOL_HOOKS = {}\n_TOOL_DESCRIPTIONS = {}\n\n\ndef register_tool(func: callable):\n    tool_name = func.__name__\n    tool_description = inspect.getdoc(func).strip()\n    python_params = inspect.signature(func).parameters\n    tool_params = []\n    for name, param in python_params.items():\n        annotation = param.annotation\n        if annotation is inspect.Parameter.empty:\n            raise TypeError(f\"Parameter `{name}` missing type annotation\")\n        if get_origin(annotation) != Annotated:\n            raise TypeError(f\"Annotation type for `{name}` must be typing.Annotated\")\n\n        typ, (description, required) = annotation.__origin__, annotation.__metadata__\n        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__\n        if not isinstance(description, str):\n            raise TypeError(f\"Description for `{name}` must be a string\")\n        if not isinstance(required, bool):\n            raise TypeError(f\"Required for `{name}` must be a bool\")\n\n        tool_params.append({\n            \"name\": name,\n            \"description\": description,\n            \"type\": typ,\n            \"required\": required\n        })\n    tool_def = {\n        \"name\": tool_name,\n        \"description\": tool_description,\n        \"params\": tool_params\n    }\n    print(\"[registered tool] \" + pformat(tool_def))\n    _TOOL_HOOKS[tool_name] = func\n    _TOOL_DESCRIPTIONS[tool_name] = tool_def\n\n    return func\n\n\ndef dispatch_tool(tool_name: str, tool_params: dict) -> str:\n    if 
tool_name not in _TOOL_HOOKS:\n        return f\"Tool `{tool_name}` not found. Please use a provided tool.\"\n    tool_call = _TOOL_HOOKS[tool_name]\n    try:\n        ret = tool_call(**tool_params)\n    except:\n        ret = traceback.format_exc()\n    return str(ret)\n\n\ndef get_tools() -> dict:\n    return copy.deepcopy(_TOOL_DESCRIPTIONS)\n\n\n# Tool Definitions\n\n@register_tool\ndef random_number_generator(\n        seed: Annotated[int, 'The random seed used by the generator', True],\n        range: Annotated[tuple[int, int], 'The range of the generated numbers', True],\n) -> int:\n    \"\"\"\n    Generates a random number x, s.t. range[0] <= x < range[1]\n    \"\"\"\n    if not isinstance(seed, int):\n        raise TypeError(\"Seed must be an integer\")\n    if not isinstance(range, tuple):\n        raise TypeError(\"Range must be a tuple\")\n    if not isinstance(range[0], int) or not isinstance(range[1], int):\n        raise TypeError(\"Range must be a tuple of integers\")\n\n    import random\n    return random.Random(seed).randint(*range)\n\n\n@register_tool\ndef get_weather(\n        city_name: Annotated[str, 'The name of the city to be queried', True],\n) -> str:\n    \"\"\"\n    Get the current weather for `city_name`\n    \"\"\"\n\n    if not isinstance(city_name, str):\n        raise TypeError(\"City name must be a string\")\n\n    key_selection = {\n        \"current_condition\": [\"temp_C\", \"FeelsLikeC\", \"humidity\", \"weatherDesc\", \"observation_time\"],\n    }\n    import requests\n    try:\n        resp = requests.get(f\"https://wttr.in/{city_name}?format=j1\")\n        resp.raise_for_status()\n        resp = resp.json()\n        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}\n    except:\n        import traceback\n        ret = \"Error encountered while fetching weather data!\\n\" + traceback.format_exc()\n\n    return str(ret)\n\n\n@register_tool\ndef get_shell(\n        query: Annotated[str, 'The command 
to run in the Linux shell', True],\n) -> str:\n    \"\"\"\n    Run a command in the Linux shell and return its output\n    \"\"\"\n    if not isinstance(query, str):\n        raise TypeError(\"Command must be a string\")\n    try:\n        result = subprocess.run(query, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,\n                                text=True)\n        return result.stdout\n    except subprocess.CalledProcessError as e:\n        return e.stderr\n\n\nif __name__ == \"__main__\":\n    # print(dispatch_tool(\"get_shell\", {\"query\": \"pwd\"}))\n    print(get_tools())\n"
  },
  {
    "path": "finetune_demo/README.md",
    "content": "# ChatGLM3-6B 微调\n\n本目录提供 ChatGLM3-6B 模型的微调示例，包括全量微调和 P-Tuning v2。格式上，提供多轮对话微调样例和输入输出格式微调样例。\n\n如果将模型下载到了本地，本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。\n\n运行示例需要 `python>=3.10`，除基础的 `torch` 依赖外，示例代码运行还需要依赖。\n\n**我们提供了 [示例notebook](lora_finetune.ipynb) 用于演示如何使用我们的微调代码。**\n\n```bash\npip install -r requirements.txt\n```\n\n## 测试硬件标准\n\n我们仅提供了单机多卡/多机多卡的运行示例，因此您需要至少一台具有多个 GPU 的机器。本仓库中的**默认配置文件**中，我们记录了显存的占用情况：\n\n+ SFT 全量微调: 4张显卡平均分配，每张显卡占用 `48346MiB` 显存。\n+ P-TuningV2 微调: 1张显卡，占用 `18426MiB` 显存。\n+ LORA 微调: 1张显卡，占用 `14082MiB` 显存。\n\n> 请注意，该结果仅供参考，对于不同的参数，显存占用可能会有所不同。请结合你的硬件情况进行调整。\n\n> 请注意，我们仅仅使用英伟达 Hopper(代表显卡：H100) 和 Ampère(代表显卡:A100) 架构和系列显卡做过测试。如果您使用其他架构的显卡，可能会出现\n> 1. 未知的训练问题 / 显存占用与上述有误差。\n> 2. 架构过低而不支持某些特性。\n> 3. 推理效果问题。\n     > 以上三种情况为社区曾经遇到过的问题，虽然概率较低，如果您遇到了以上问题，可以尝试在社区中解决。\n\n## 多轮对话格式\n\n多轮对话微调示例采用 ChatGLM3 对话格式约定，对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。\n\n对于数据文件，样例采用如下格式\n\n如果您仅希望微调模型的对话能力，而非工具能力，您应该按照以下格式整理数据。\n\n```json\n[\n  {\n    \"conversations\": [\n      {\n        \"role\": \"system\",\n        \"content\": \"<system prompt text>\"\n      },\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      },\n      // ... 
Muti Turn\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      }\n    ]\n  }\n  // ...\n]\n```\n\n**请注意，这种方法在微调的step较多的情况下会影响到模型的工具调用功能**\n\n如果您希望微调模型的对话和工具能力，您应该按照以下格式整理数据。\n\n```json\n[\n  {\n    \"tools\": [\n      // available tools, format is not restricted\n    ],\n    \"conversations\": [\n      {\n        \"role\": \"system\",\n        \"content\": \"<system prompt text>\"\n      },\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant thought to text>\"\n      },\n      {\n        \"role\": \"tool\",\n        \"name\": \"<name of the tool to be called\",\n        \"parameters\": {\n          \"<parameter_name>\": \"<parameter_value>\"\n        },\n        \"observation\": \"<observation>\"\n        // don't have to be string\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response to observation>\"\n      },\n      // ... 
Multi Turn\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      }\n    ]\n  }\n  // ...\n]\n```\n\n- 关于工具描述的 system prompt 无需手动插入，预处理时会将 `tools` 字段使用 `json.dumps(..., ensure_ascii=False)`\n  格式化后插入为首条 system prompt。\n\n- 每种角色可以附带一个 `bool` 类型的 `loss` 字段，表示该字段所预测的内容是否参与 `loss`\n  计算。若没有该字段，样例实现中默认对 `system`, `user` 不计算 `loss`，其余角色则计算 `loss`。\n\n- `tool` 并不是 ChatGLM3 中的原生角色，这里的 `tool` 在预处理阶段将被自动转化为一个具有工具调用 `metadata` 的 `assistant`\n  角色（默认计算 `loss`）和一个表示工具返回值的 `observation` 角色（不计算 `loss`）。\n\n- 目前暂未实现 `Code interpreter` 的微调任务。\n\n- `system` 角色为可选角色，但若存在 `system` 角色，其必须出现在 `user`\n  角色之前，且一个完整的对话数据（无论单轮或者多轮对话）只能出现一次 `system` 角色。\n\n## 数据集格式示例\n\n这里以 AdvertiseGen 数据集为例，\n您可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)\n或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载 AdvertiseGen 数据集。\n将解压后的 AdvertiseGen 目录放到 `data` 目录下并自行转换为如下格式数据集。\n\n> 请注意，现在的微调代码中加入了验证集，因此，对于一组完整的微调数据集，必须包含训练数据集和验证数据集，测试数据集可以不填写，或者直接用验证数据集代替。\n\n```\n{\"conversations\": [{\"role\": \"user\", \"content\": \"类型#裙*裙长#半身裙\"}, {\"role\": \"assistant\", \"content\": \"这款百搭时尚的仙女半身裙，整体设计非常的飘逸随性，穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈，透气性也很好，穿到夏天也很舒适。\"}]}\n```\n\n## 配置文件\n\n微调配置文件位于 `config` 目录下，包括以下文件：\n\n1. `ds_zero_2 / ds_zero_3.json`: deepspeed 配置文件。\n2. 
`lora.yaml / ptuning_v2.yaml / sft.yaml`: 不同微调方式的配置文件，包括模型参数、优化器参数、训练参数等。部分重要参数解释如下：\n    + data_config 部分\n        + train_file: 训练数据集的文件路径。\n        + val_file: 验证数据集的文件路径。\n        + test_file: 测试数据集的文件路径。\n        + num_proc: 在加载数据时使用的进程数量。\n    + max_input_length: 输入序列的最大长度。\n    + max_output_length: 输出序列的最大长度。\n    + training_args 部分\n        + output_dir: 用于保存模型和其他输出的目录。\n        + max_steps: 训练的最大步数。\n        + per_device_train_batch_size: 每个设备（如 GPU）的训练批次大小。\n        + dataloader_num_workers: 加载数据时使用的工作线程数量。\n        + remove_unused_columns: 是否移除数据中未使用的列。\n        + save_strategy: 模型保存策略（例如，每隔多少步保存一次）。\n        + save_steps: 每隔多少步保存一次模型。\n        + log_level: 日志级别（如 info）。\n        + logging_strategy: 日志记录策略。\n        + logging_steps: 每隔多少步记录一次日志。\n        + per_device_eval_batch_size: 每个设备的评估批次大小。\n        + evaluation_strategy: 评估策略（例如，每隔多少步进行一次评估）。\n        + eval_steps: 每隔多少步进行一次评估。\n        + predict_with_generate: 是否使用生成模式进行预测。\n    + generation_config 部分\n        + max_new_tokens: 生成的最大新 token 数量。\n    + peft_config 部分\n        + peft_type: 使用的参数高效微调类型（如 LORA）。\n        + task_type: 任务类型，这里是因果语言模型（CAUSAL_LM）。\n    + Lora 参数：\n        + r: LoRA 的秩。\n        + lora_alpha: LoRA 的缩放因子。\n        + lora_dropout: 在 LoRA 层使用的 dropout 概率。\n    + P-TuningV2 参数：\n        + num_virtual_tokens: 虚拟 token 的数量。\n\n## 开始微调\n\n通过以下代码执行 **单机多卡/多机多卡** 运行。该方式使用 `deepspeed` 作为加速方案，您需要安装 `deepspeed`。\n\n```bash\ncd finetune_demo\nOMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8  finetune_hf.py  data/AdvertiseGen/  THUDM/chatglm3-6b  configs/lora.yaml\n```\n\n通过以下代码执行 **单机单卡** 运行。\n\n```bash\ncd finetune_demo\npython finetune_hf.py  data/AdvertiseGen/  THUDM/chatglm3-6b  configs/lora.yaml\n```\n\n## 从保存点进行微调\n\n如果按照上述方式进行训练，每次微调都会从头开始。如果你想从训练一半的模型开始微调，你可以加入第四个参数，这个参数有两种传入方式:\n\n1. `yes`, 自动从最后一个保存的 Checkpoint 开始训练\n2. 
`XX`, 断点号数字，例如 `600`，则从序号为 600 的 Checkpoint 开始训练\n\n例如，这就是一个从最后一个保存点继续微调的示例代码：\n\n```bash\ncd finetune_demo\npython finetune_hf.py  data/AdvertiseGen/  THUDM/chatglm3-6b  configs/lora.yaml yes\n```\n\n\n## 使用微调后的模型\n\n### 在 inference_hf.py 中验证微调后的模型\n\n您可以在 `finetune_demo/inference_hf.py` 中使用我们微调后的模型，仅需要一行代码就能简单地进行测试。\n\n```bash\npython inference_hf.py your_finetune_path --prompt your prompt\n```\n\n这样，得到的回答就是微调后的回答了。\n\n### 在本仓库的其他 demo 或者外部仓库使用微调后的模型\n\n您可以在任何一个 demo 内使用我们的 `lora` 和全参微调的模型。这需要您按照以下教程自行修改代码。\n\n1. 使用 `finetune_demo/inference_hf.py` 中读入模型的方式替换 demo 中读入模型的方式。\n\n> 请注意，对于 LORA 和 P-TuningV2，我们没有合并训练后的模型，而是在 `adapter_config.json`\n> 中记录了微调模型的路径，如果你的原始模型位置发生更改，则你应该修改 `adapter_config.json` 中 `base_model_name_or_path` 的路径。\n\n```python\ndef load_model_and_tokenizer(\n        model_dir: Union[str, Path], trust_remote_code: bool = True\n) -> tuple[ModelType, TokenizerType]:\n    model_dir = _resolve_path(model_dir)\n    if (model_dir / 'adapter_config.json').exists():\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model.peft_config['default'].base_model_name_or_path\n    else:\n        model = AutoModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model_dir\n    tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_dir, trust_remote_code=trust_remote_code\n    )\n    return model, tokenizer\n```\n\n2. 读取微调的模型时，请注意使用微调模型的位置。例如，若你的微调模型位置为 `/path/to/finetune_adapter_model`，原始模型地址为 `/path/to/base_model`，则你应该使用 `/path/to/finetune_adapter_model` 作为 `model_dir`。\n3. 完成上述操作后，就能正常使用微调的模型了，其他的调用方式没有变化。\n\n### 提示\n\n1. 
微调代码在开始训练前，会先打印首条训练数据的预处理信息（默认已经注释，可以解除注释），显示为\n\n```log\nSanity\nCheck >> >> >> >> >> >> >\n'[gMASK]': 64790 ->   -100\n'sop': 64792 ->   -100\n'<|system|>': 64794 ->   -100\n'': 30910 ->   -100\n'\\n': 13 ->   -100\n'Answer': 20115 ->   -100\n'the': 267 ->   -100\n'following': 1762 ->   -100\n...\n'know': 683 ->   -100\n'the': 267 ->   -100\n'response': 3010 ->   -100\n'details': 3296 ->   -100\n'.': 30930 ->   -100\n'<|assistant|>': 64796 ->   -100\n'': 30910 ->  30910\n'\\n': 13 ->     13\n'I': 307 ->    307\n'need': 720 ->    720\n'to': 289 ->    289\n'use': 792 ->    792\n...\n<< << << << << << < Sanity\nCheck\n```\n\n字样，每行依次表示一个 detokenized string, token_id 和 target_id。其中，`target_id` 为该位置的标签（label），`-100` 表示该 token 不参与 `loss` 计算。\n\n2. `_prepare_model_for_training` 的作用是遍历模型的所有可训练参数，并确保它们的数据类型为 `torch.float32`。\n   这在某些情况下是必要的，因为混合精度训练或其他操作可能会更改模型参数的数据类型。该代码默认打开，可以注释掉；但如果使用\n   `half` 格式训练出现问题，可以切换回这段代码，此时显存占用可能增加。\n3. 在我们的[Huggingface模型代码](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py)中，有以下内容：\n    ```python\n   if self.gradient_checkpointing and self.training:\n                layer_ret = torch.utils.checkpoint.checkpoint(\n                    layer,\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    kv_caches[index],\n                    use_cache,\n                    use_reentrant=False\n                )\n   ```\n   这可能导致训练的时候显存增加，因此，如果您的显存不足，可以尝试将 `use_reentrant` 修改为 `True`。\n4. 微调后的模型可以使用任何支持 `peft` 载入的模型加速框架，在这里，我们没有提供 demo。\n5. 
本仓库的微调数据集格式与 API 微调数据集格式有一定区别：\n    + ZhipuAI API 微调数据集中的 `messages` 字段在本仓库为 `conversations` 字段。\n    + ZhipuAI API 中的微调文件为 `jsonl` 格式，在本仓库中只需将文件后缀改为 `json`。\n\n## 参考文献\n\n```\n@inproceedings{liu2022p,\ntitle={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},\nauthor={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},\nbooktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},\npages={61--68},\nyear={2022}\n}\n\n@misc{tang2023toolalpaca,\ntitle={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},\nauthor={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},\nyear={2023},\neprint={2306.05301},\narchivePrefix={arXiv},\nprimaryClass={cs.CL}\n}\n```"
  },
  {
    "path": "finetune_demo/README_en.md",
    "content": "# ChatGLM3-6B finetune\n\nThis directory provides fine-tuning examples of the ChatGLM3-6B model, including full fine-tuning and P-Tuning v2. In\nterms of format, it provides multiple rounds of dialogue fine-tuning samples and input and output format fine-tuning\nsamples.\n\nIf the model is downloaded locally, the `THUDM/chatglm3-6b` field in this article and the code should be replaced with\nthe corresponding address to load the model locally.\n\nRunning the example requires `python>=3.10`. In addition to the basic `torch` dependency, the example code also requires\ndependencies to run.\n\n**We provide [sample notebook](lora_finetune.ipynb) to demonstrate how to use our fine-tuning code. **\n\n```bash\npip install -r requirements.txt\n```\n\n## Test hardware standards\n\nWe only provide single-machine multi-card/multi-machine multi-card running examples, so you will need at least one\nmachine with multiple GPUs. In the **default configuration file** in this warehouse, we record the usage of video\nmemory:\n\n+ SFT full fine-tuning: evenly distributed among 4 graphics cards, each graphics card occupies `48346MiB` of video\n  memory.\n+ P-TuningV2 fine-tuning: 1 graphics card, occupying `18426MiB` memory.\n+ LORA fine-tuning: 1 graphics card, occupying `14082MiB` memory.\n\n> Please note that this result is for reference only, and the memory usage may be different for different parameters.\n> Please make adjustments based on your hardware conditions.\n\n## Multi-turn dialogue format\n\nThe multi-round dialogue fine-tuning example adopts the ChatGLM3 dialogue format convention and adds\ndifferent `loss_mask` to different characters to calculate `loss` for multiple rounds of responses in one pass.\n\nFor data files, the sample adopts the following format\n\nIf you only want to fine-tune your model's conversational capabilities, rather than its tool capabilities, you should\norganize your data in the following format.\n\n```json\n[\n  {\n    
\"conversations\": [\n      {\n        \"role\": \"system\",\n        \"content\": \"<system prompt text>\"\n      },\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      },\n      // ... Muti Turn\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      }\n    ]\n  }\n  // ...\n]\n```\n\n**Please note that this method will affect the tool calling function of the model when there are many fine-tuning steps\n**\n\nIf you wish to fine-tune your model's dialog and tool capabilities, you should organize your data in the following\nformat.\n\n```json\n[\n  {\n    \"tools\": [\n      // available tools, format is not restricted\n    ],\n    \"conversations\": [\n      {\n        \"role\": \"system\",\n        \"content\": \"<system prompt text>\"\n      },\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant thought to text>\"\n      },\n      {\n        \"role\": \"tool\",\n        \"name\": \"<name of the tool to be called\",\n        \"parameters\": {\n          \"<parameter_name>\": \"<parameter_value>\"\n        },\n        \"observation\": \"<observation>\"\n        // don't have to be string\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response to observation>\"\n      },\n      // ... Muti Turn\n      {\n        \"role\": \"user\",\n        \"content\": \"<user prompt text>\"\n      },\n      {\n        \"role\": \"assistant\",\n        \"content\": \"<assistant response text>\"\n      }\n    ]\n  }\n  // ...\n]\n```\n\n- There is no need to manually insert the system prompt about the tool description. 
The `tools` field will be formatted with `json.dumps(..., ensure_ascii=False)` during preprocessing and inserted as the first system prompt.\n\n- Each role can be accompanied by a `loss` field of type `bool`, indicating whether the content predicted by this field participates in the `loss` calculation. If there is no such field, the sample implementation does not calculate `loss` for `system` and `user` by default, but calculates `loss` for the other roles.\n\n- `tool` is not a native role in ChatGLM3. During preprocessing, the `tool` role will be automatically converted into an `assistant` role carrying tool-call `metadata` (`loss` calculated by default) and an `observation` role representing the tool return value (`loss` not calculated).\n\n- The fine-tuning task for the `Code interpreter` has not been implemented yet.\n\n- The `system` role is optional, but if the `system` role exists, it must appear before the `user` role, and the `system` role can only appear once in a complete dialogue sample (whether single-turn or multi-turn).\n\n## Dataset format example\n\nHere we take the AdvertiseGen dataset as an example. You can download the AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1).\nPlace the decompressed AdvertiseGen directory in the `data` directory and convert it into the following dataset format yourself.\n\n> Please note that a validation set has been added to the current fine-tuning code. Therefore, a complete set of fine-tuning data must include a training dataset and a validation dataset; the test dataset does not need to be filled in. 
Or the validation dataset can be used in its place directly.\n\n```\n{\"conversations\": [{\"role\": \"user\", \"content\": \"Type#skirt*skirt length#skirt\"}, {\"role\": \"assistant\", \"content\": \"This versatile and fashionable fairy skirt has an elegant, casual overall design. Every girl can instantly turn into a fairy after wearing it. The material is very light and breathable, making it very comfortable to wear in summer.\"}]}\n```\n\n## Configuration files\n\nFine-tuning configuration files are located in the `config` directory and include the following files:\n\n1. `ds_zero_2 / ds_zero_3.json`: deepspeed configuration files.\n2. `lora.yaml / ptuning_v2.yaml / sft.yaml`: Configuration files for the different fine-tuning methods, including model parameters, optimizer parameters, training parameters, etc. Some important parameters are explained as follows:\n    + data_config section\n        + train_file: The file path of the training dataset.\n        + val_file: The file path of the validation dataset.\n        + test_file: The file path of the test dataset.\n        + num_proc: Number of processes used when loading data.\n    + max_input_length: The maximum length of the input sequence.\n    + max_output_length: The maximum length of the output sequence.\n    + training_args section\n        + output_dir: Directory for saving the model and other outputs.\n        + max_steps: The maximum number of training steps.\n        + per_device_train_batch_size: Training batch size per device (e.g. 
GPU).\n        + dataloader_num_workers: The number of worker threads used when loading data.\n        + remove_unused_columns: Whether to remove unused columns in the data.\n        + save_strategy: The model saving strategy (for example, save every N steps).\n        + save_steps: Save the model every N steps.\n        + log_level: The log level (such as info).\n        + logging_strategy: The logging strategy.\n        + logging_steps: Log every N steps.\n        + per_device_eval_batch_size: Evaluation batch size per device.\n        + evaluation_strategy: The evaluation strategy (for example, evaluate every N steps).\n        + eval_steps: Evaluate every N steps.\n        + predict_with_generate: Whether to use generate mode for prediction.\n    + generation_config section\n        + max_new_tokens: The maximum number of new tokens to generate.\n    + peft_config section\n        + peft_type: The parameter-efficient fine-tuning type used (e.g. LORA).\n        + task_type: The task type, here the causal language model (CAUSAL_LM).\n    + Lora parameters:\n        + r: The LoRA rank.\n        + lora_alpha: The scaling factor for LoRA.\n        + lora_dropout: The dropout probability used in the LoRA layers.\n    + P-TuningV2 parameters:\n        + num_virtual_tokens: The number of virtual tokens.\n\n## Start fine-tuning\n\nUse the following commands to run **single-node multi-GPU / multi-node multi-GPU** fine-tuning. This uses `deepspeed` as the acceleration solution, so you need to install `deepspeed`.\n\n```bash\ncd finetune_demo\nOMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8  finetune_hf.py  data/AdvertiseGen/  THUDM/chatglm3-6b  configs/lora.yaml\n```\n\nRun **single-node single-GPU** fine-tuning with the following commands.\n\n```bash\ncd finetune_demo\npython finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml\n```\n\n## Fine-tuning from a checkpoint\n\nIf you train according to the above method, each fine-tuning run will start from scratch. 
If you want to resume fine-tuning from a partially trained model, you can add a fourth parameter, which can be passed in two ways:\n\n1. `yes`, automatically resume training from the last saved Checkpoint\n2. `XX`, a checkpoint number; for example, `600` means resuming training from Checkpoint number 600\n\nFor example, this is an example of continuing fine-tuning from the last saved point:\n\n```bash\ncd finetune_demo\npython finetune_hf.py  data/AdvertiseGen/  THUDM/chatglm3-6b  configs/lora.yaml yes\n```\n\n## Use the fine-tuned model\n\n### Verify the fine-tuned model in inference_hf.py\n\nYou can use our fine-tuned model in `finetune_demo/inference_hf.py`, which can be easily tested with just one line of code.\n\n```bash\npython inference_hf.py your_finetune_path --prompt your prompt\n```\n\nIn this way, the answer you get is the fine-tuned answer.\n\n### Use the fine-tuned model in other demos in this repo or external repos\n\nYou can use our `lora` and full-parameter fine-tuned models in any demo, as follows:\n\n1. Use the method of reading the model in `finetune_demo/inference_hf.py` to replace the method of reading the model in the demo.\n\n> Please note that for LORA and P-TuningV2 we do not merge the trained models; instead, the path of the fine-tuned model is recorded in `adapter_config.json`. If your original model location changes, you should modify the path of `base_model_name_or_path` in `adapter_config.json`.\n\n> Please note that we have only tested with NVIDIA Hopper (representative GPU: H100) and Ampere (representative GPU: A100) architecture GPUs. If you use a GPU of another architecture, you may experience\n> 1. Unknown training problems / GPU memory usage that differs from the figures above.\n> 2. An architecture that is too old to support certain features.\n> 3. Inference quality problems.\n> The above three situations are problems that the community has encountered before. 
Although the probability is extremely low, if you encounter the above problems, you can try to solve them in the community.\n\n```python\ndef load_model_and_tokenizer(\n        model_dir: Union[str, Path], trust_remote_code: bool = True\n) -> tuple[ModelType, TokenizerType]:\n    model_dir = _resolve_path(model_dir)\n    if (model_dir / 'adapter_config.json').exists():\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model.peft_config['default'].base_model_name_or_path\n    else:\n        model = AutoModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=trust_remote_code, device_map='auto'\n        )\n        tokenizer_dir = model_dir\n    tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_dir, trust_remote_code=trust_remote_code\n    )\n    return model, tokenizer\n```\n\n2. Read the fine-tuned model. Please note that you should use the location of the fine-tuned model: for example, if your fine-tuned model location is `/path/to/finetune_adapter_model` and the original model address is `/path/to/base_model`, then you should use `/path/to/finetune_adapter_model` as `model_dir`.\n3. After completing the above operations, you can use the fine-tuned model normally; the other calling methods remain unchanged.\n\n### Tips\n\n1. 
Before starting training, the fine-tuning code will print the preprocessing information of the first training sample (commented out by default; it can be uncommented), displayed as\n\n```log\nSanity\nCheck >> >> >> >> >> >> >\n'[gMASK]': 64790 ->   -100\n'sop': 64792 ->   -100\n'<|system|>': 64794 ->   -100\n'': 30910 ->   -100\n'\\n': 13 ->   -100\n'Answer': 20115 ->   -100\n'the': 267 ->   -100\n'following': 1762 ->   -100\n...\n'know': 683 ->   -100\n'the': 267 ->   -100\n'response': 3010 ->   -100\n'details': 3296 ->   -100\n'.': 30930 ->   -100\n'<|assistant|>': 64796 ->   -100\n'': 30910 ->  30910\n'\\n': 13 ->     13\n'I': 307 ->    307\n'need': 720 ->    720\n'to': 289 ->    289\n'use': 792 ->    792\n...\n<< << << << << << < Sanity\nCheck\n```\n\nEach line shows, in turn, a detokenized string, its token_id, and its target_id. Here, `target_id` is the label for that position, and `-100` means the token does not participate in the `loss` calculation.\n\n2. The function of `_prepare_model_for_training` is to iterate through all the trainable parameters of the model and ensure that their data type is `torch.float32`. This is necessary in some cases because mixed-precision training or other operations may change the data type of the model parameters. This code is enabled by default and can be commented out; however, if problems occur when training in `half` format, you can switch back to this code, at the cost of possibly higher GPU memory usage.\n3. 
In our [Huggingface model code](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py), there is the following content:\n    ```python\n   if self.gradient_checkpointing and self.training:\n                layer_ret = torch.utils.checkpoint.checkpoint(\n                    layer,\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    kv_caches[index],\n                    use_cache,\n                    use_reentrant=False\n                )\n   ```\n   This may cause GPU memory usage to increase during training, so if you have insufficient GPU memory, you can try changing `use_reentrant` to `True`.\n4. The fine-tuned model can be used with any model acceleration framework that supports `peft` loading; we do not provide a demo here.\n5. The fine-tuning dataset format of this repository differs somewhat from the API fine-tuning dataset format.\n    + The `messages` field in the ZhipuAI API fine-tuning dataset corresponds to the `conversations` field in this repository.\n    + The fine-tuning file for the ZhipuAI API is `jsonl`; in this repository, you simply need to change the file extension to `json`.\n\n## Citation\n\n```\n@inproceedings{liu2022p,\ntitle={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},\nauthor={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},\nbooktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},\npages={61--68},\nyear={2022}\n}\n\n@misc{tang2023toolalpaca,\ntitle={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},\nauthor={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},\nyear={2023},\neprint={2306.05301},\narchivePrefix={arXiv},\nprimaryClass={cs.CL}\n}\n```"
  },
  {
    "path": "finetune_demo/configs/ds_zero_2.json",
    "content": "{\n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"loss_scale\": 0,\n        \"loss_scale_window\": 1000,\n        \"initial_scale_power\": 16,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \"zero_optimization\": {\n        \"stage\": 2,\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 5e8,\n        \"overlap_comm\": true,\n        \"reduce_scatter\": true,\n        \"reduce_bucket_size\": 5e8,\n        \"contiguous_gradients\": true\n    },\n\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 2000,\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"wall_clock_breakdown\": false\n}"
  },
  {
    "path": "finetune_demo/configs/ds_zero_3.json",
    "content": "{\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"zero_allow_untested_optimizer\": true,\n  \"bf16\": {\n    \"enabled\": \"auto\"\n  },\n  \"optimizer\": {\n    \"type\": \"AdamW\",\n    \"params\": {\n      \"lr\": \"auto\",\n      \"betas\": \"auto\",\n      \"eps\": \"auto\",\n      \"weight_decay\": \"auto\"\n    }\n  },\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"allgather_partitions\": true,\n    \"allgather_bucket_size\": 5e8,\n    \"reduce_scatter\": true,\n    \"contiguous_gradients\": true,\n    \"overlap_comm\": true,\n    \"sub_group_size\": 1e9,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\": \"auto\",\n    \"stage3_param_persistence_threshold\": \"auto\",\n    \"stage3_max_live_parameters\": 1e9,\n    \"stage3_max_reuse_distance\": 1e9,\n    \"stage3_gather_16bit_weights_on_model_save\": true\n  }\n}"
  },
  {
    "path": "finetune_demo/configs/lora.yaml",
    "content": "data_config:\n  train_file: train.json\n  val_file: dev.json\n  test_file: dev.json\n  num_proc: 16\nmax_input_length: 256\nmax_output_length: 512\ntraining_args:\n  # see `transformers.Seq2SeqTrainingArguments`\n  output_dir: ./output\n  max_steps: 3000\n  # needed to be fit for the dataset\n  learning_rate: 5e-5\n  # settings for data loading\n  per_device_train_batch_size: 4\n  dataloader_num_workers: 16\n  remove_unused_columns: false\n  # settings for saving checkpoints\n  save_strategy: steps\n  save_steps: 500\n  # settings for logging\n  log_level: info\n  logging_strategy: steps\n  logging_steps: 10\n  # settings for evaluation\n  per_device_eval_batch_size: 16\n  evaluation_strategy: steps\n  eval_steps: 500\n  # settings for optimizer\n  # adam_epsilon: 1e-6\n  # uncomment the following line to detect nan or inf values\n  # debug: underflow_overflow\n  predict_with_generate: true\n  # see `transformers.GenerationConfig`\n  generation_config:\n    max_new_tokens: 512\n  # set your absolute deepspeed path here\n  #deepspeed: ds_zero_2.json\n  # set to true if train with cpu.\n  use_cpu: false\npeft_config:\n  peft_type: LORA\n  task_type: CAUSAL_LM\n  r: 8\n  lora_alpha: 32\n  lora_dropout: 0.1\n"
  },
  {
    "path": "finetune_demo/configs/ptuning_v2.yaml",
    "content": "data_config:\n  train_file: train.json\n  val_file: dev.json\n  test_file: dev.json\n  num_proc: 16\nmax_input_length: 256\nmax_output_length: 512\ntraining_args:\n  # see `transformers.Seq2SeqTrainingArguments`\n  output_dir: ./output\n  max_steps: 3000\n  # needed to be fit for the dataset\n  learning_rate: 5e-5\n  # settings for data loading\n  per_device_train_batch_size: 4\n  dataloader_num_workers: 16\n  remove_unused_columns: false\n  # settings for saving checkpoints\n  save_strategy: steps\n  save_steps: 500\n  # settings for logging\n  log_level: info\n  logging_strategy: steps\n  logging_steps: 10\n  # settings for evaluation\n  per_device_eval_batch_size: 16\n  evaluation_strategy: steps\n  eval_steps: 500\n  # settings for optimizer\n  # adam_epsilon: 1e-6\n  # uncomment the following line to detect nan or inf values\n  # debug: underflow_overflow\n  predict_with_generate: true\n  # see `transformers.GenerationConfig`\n  generation_config:\n    max_new_tokens: 512\n  # set your absolute deepspeed path here\n  #deepspeed: ds_zero_3.json\n  use_cpu: false\npeft_config:\n  peft_type: PREFIX_TUNING\n  task_type: CAUSAL_LM\n  num_virtual_tokens: 128\n"
  },
  {
    "path": "finetune_demo/configs/sft.yaml",
    "content": "data_config:\n  train_file: train.json\n  val_file: dev.json\n  test_file: dev.json\n  num_proc: 16\nmax_input_length: 256\nmax_output_length: 512\ntraining_args:\n  # see `transformers.Seq2SeqTrainingArguments`\n  output_dir: ./output\n  max_steps: 3000\n  # needed to be fit for the dataset\n  learning_rate: 5e-5\n  # settings for data loading\n  per_device_train_batch_size: 4\n  dataloader_num_workers: 16\n  remove_unused_columns: false\n  # settings for saving checkpoints\n  save_strategy: steps\n  save_steps: 500\n  # settings for logging\n  log_level: info\n  logging_strategy: steps\n  logging_steps: 10\n  # settings for evaluation\n  per_device_eval_batch_size: 16\n  evaluation_strategy: steps\n  eval_steps: 500\n  # settings for optimizer\n  # adam_epsilon: 1e-6\n  # uncomment the following line to detect nan or inf values\n  # debug: underflow_overflow\n  predict_with_generate: true\n  generation_config:\n    max_new_tokens: 512\n  # set your absolute deepspeed path here\n  deepspeed: ds_zero_3.json\n"
  },
  {
    "path": "finetune_demo/finetune_hf.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport jieba\nimport dataclasses as dc\nimport functools\nfrom collections.abc import Callable, Mapping, Sequence\nfrom pathlib import Path\nfrom typing import Annotated, Any, Optional, Union\nimport numpy as np\nimport ruamel.yaml as yaml\nimport torch\nimport typer\nfrom datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset\nfrom nltk.translate.bleu_score import SmoothingFunction, sentence_bleu\nfrom peft import (\n    PeftConfig,\n    PeftModelForCausalLM,\n    get_peft_config,\n    get_peft_model\n)\nfrom rouge_chinese import Rouge\nfrom torch import nn\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    EvalPrediction,\n    GenerationConfig,\n    PreTrainedModel,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n    Seq2SeqTrainingArguments, AutoConfig,\n)\nfrom transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq\n\nfrom transformers import Seq2SeqTrainer as _Seq2SeqTrainer\n\nModelType = Union[PreTrainedModel, PeftModelForCausalLM]\nTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\napp = typer.Typer(pretty_exceptions_show_locals=False)\n\n\nclass DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):\n    def __call__(self, features, return_tensors=None):\n        output_ids = (\n            [feature['output_ids'] for feature in features]\n            if 'output_ids' in features[0].keys()\n            else None\n        )\n        if output_ids is not None:\n            max_output_length = max(len(out) for out in output_ids)\n            if self.pad_to_multiple_of is not None:\n                max_output_length = (\n                        (\n                                max_output_length + self.pad_to_multiple_of - 1) //\n                        self.pad_to_multiple_of * self.pad_to_multiple_of\n                )\n            for feature in features:\n                remainder = [self.tokenizer.pad_token_id] * (\n           
             max_output_length - len(feature['output_ids'])\n                )\n                if isinstance(feature['output_ids'], list):\n                    feature['output_ids'] = feature['output_ids'] + remainder\n                else:\n                    feature['output_ids'] = np.concatenate(\n                        [feature['output_ids'], remainder]\n                    ).astype(np.int64)\n        return super().__call__(features, return_tensors)\n\n\nclass Seq2SeqTrainer(_Seq2SeqTrainer):\n    def prediction_step(\n            self,\n            model: nn.Module,\n            inputs: dict[str, Any],\n            prediction_loss_only: bool,\n            ignore_keys=None,\n            **gen_kwargs,\n    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:\n        if self.args.predict_with_generate:\n            output_ids = inputs.pop('output_ids')\n        input_ids = inputs['input_ids']\n        loss, generated_tokens, labels = super().prediction_step(\n            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs\n        )\n        generated_tokens = generated_tokens[:, input_ids.size()[1]:]\n        if self.args.predict_with_generate:\n            labels = output_ids\n        return loss, generated_tokens, labels\n    # For P-Tuning a new save_model function is fine for the prefix_encoder model\n    # but may cost problems for the whole model loading\n\n    # def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):\n    #     if output_dir is None:\n    #         output_dir = self.args.output_dir\n    #     os.makedirs(output_dir, exist_ok=True)\n    #     ptuning_params = {k: v for k, v in self.model.transformer.prefix_encoder.state_dict().items()}\n    #\n    #     torch.save(ptuning_params, os.path.join(output_dir, 'pytorch_model.bin'))\n    #\n    #     print(f\"P-Tuning model weights saved in {output_dir}\")\n    #\n    #     if self.tokenizer is not None:\n    #         
self.tokenizer.save_pretrained(output_dir)\n\n\ndef _resolve_path(path: Union[str, Path]) -> Path:\n    return Path(path).expanduser().resolve()\n\n\ndef _sanity_check(\n        input_ids: Sequence[int],\n        output_ids: Sequence[int],\n        tokenizer: PreTrainedTokenizer,\n):\n    print('--> Sanity check')\n    for in_id, out_id in zip(input_ids, output_ids):\n        if in_id == 0:\n            continue\n        if in_id in tokenizer.tokenizer.index_special_tokens:\n            in_text = tokenizer.tokenizer.index_special_tokens[in_id]\n        else:\n            in_text = tokenizer.decode([in_id])\n        print(f'{repr(in_text):>20}: {in_id} -> {out_id}')\n\n\n@functools.cache\ndef _get_yaml_parser() -> yaml.YAML:\n    parser = yaml.YAML(typ='safe', pure=True)\n    parser.indent(mapping=2, offset=2, sequence=4)\n    parser.default_flow_style = False\n    return parser\n\n\n@dc.dataclass\nclass DataConfig(object):\n    train_file: str\n    val_file: Optional[str] = None\n    test_file: Optional[str] = None\n\n    num_proc: Optional[int] = None\n\n    @property\n    def data_format(self) -> str:\n        return Path(self.train_file).suffix\n\n    @property\n    def data_files(self) -> dict[NamedSplit, str]:\n        return {\n            split: data_file\n            for split, data_file in zip(\n                [Split.TRAIN, Split.VALIDATION, Split.TEST],\n                [self.train_file, self.val_file, self.test_file],\n            )\n            if data_file is not None\n        }\n\n\n@dc.dataclass\nclass FinetuningConfig(object):\n    data_config: DataConfig\n\n    max_input_length: int\n    max_output_length: int\n\n    training_args: Seq2SeqTrainingArguments = dc.field(\n        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')\n    )\n    peft_config: Optional[PeftConfig] = None\n\n    def __post_init__(self):\n        if not self.training_args.do_eval or self.data_config.val_file is None:\n            # skips the evaluation 
stage when `do_eval` is disabled or `val_file` is not provided\n            self.training_args.do_eval = False\n            self.training_args.evaluation_strategy = 'no'\n            self.data_config.val_file = None\n        else:\n            self.training_args.per_device_eval_batch_size = (\n                    self.training_args.per_device_eval_batch_size\n                    or self.training_args.per_device_train_batch_size\n            )\n\n    @classmethod\n    def from_dict(cls, **kwargs) -> 'FinetuningConfig':\n        training_args = kwargs.get('training_args', None)\n        if training_args is not None and not isinstance(\n                training_args, Seq2SeqTrainingArguments\n        ):\n            gen_config = training_args.get('generation_config')\n            # TODO: a bit hacky\n            if not isinstance(gen_config, GenerationConfig):\n                training_args['generation_config'] = GenerationConfig(\n                    **gen_config\n                )\n            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)\n\n        data_config = kwargs.get('data_config')\n        if not isinstance(data_config, DataConfig):\n            kwargs['data_config'] = DataConfig(**data_config)\n\n        peft_config = kwargs.get('peft_config', None)\n        if peft_config is not None and not isinstance(peft_config, PeftConfig):\n            kwargs['peft_config'] = get_peft_config(peft_config)\n        return cls(**kwargs)\n\n    @classmethod\n    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':\n        path = _resolve_path(path)\n        kwargs = _get_yaml_parser().load(path)\n        return cls.from_dict(**kwargs)\n\n\ndef _load_datasets(\n        data_dir: Path,\n        data_format: str,\n        data_files: dict[NamedSplit, str],\n        num_proc: Optional[int],\n) -> DatasetDict:\n    if data_format in ('.csv', '.json', '.jsonl'):\n        dataset_dct = load_dataset(\n            data_format[1:],\n            
data_dir=data_dir,\n            data_files=data_files,\n            num_proc=num_proc,\n        )\n    else:\n        err_msg = f\"Cannot load dataset in the '{data_format}' format.\"\n        raise NotImplementedError(err_msg)\n\n    return dataset_dct\n\n\nclass DataManager(object):\n    def __init__(self, data_dir: str, data_config: DataConfig):\n        self._num_proc = data_config.num_proc\n\n        self._dataset_dct = _load_datasets(\n            _resolve_path(data_dir),\n            data_config.data_format,\n            data_config.data_files,\n            self._num_proc,\n        )\n\n    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:\n        return self._dataset_dct.get(split, None)\n\n    def get_dataset(\n            self,\n            split: NamedSplit,\n            process_fn: Callable[[dict[str, Any]], dict[str, Any]],\n            batched: bool = True,\n            remove_orig_columns: bool = True,\n    ) -> Optional[Dataset]:\n        orig_dataset = self._get_dataset(split)\n        if orig_dataset is None:\n            return\n\n        if remove_orig_columns:\n            remove_columns = orig_dataset.column_names\n        else:\n            remove_columns = None\n        return orig_dataset.map(\n            process_fn,\n            batched=batched,\n            remove_columns=remove_columns,\n            num_proc=self._num_proc,\n        )\n\n\ndef print_model_size(model: PreTrainedModel):\n    print(\"--> Model\")\n    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    print(f\"\\n--> model has {total_params / 1e6}M params\\n\")\n\n\ndef process_batch(\n        batch: Mapping[str, Sequence],\n        tokenizer: PreTrainedTokenizer,\n        max_input_length: int,\n        max_output_length: int,\n) -> dict[str, list]:\n    batched_tools = batch.get('tools', None)\n    batched_conv = batch['conversations']\n    batched_input_ids = []\n    batched_labels = []\n\n    if batched_tools is None:\n     
   batched_tools = [None] * len(batched_conv)\n\n    for tools, conv in zip(batched_tools, batched_conv):\n        input_ids, loss_masks = [\n            tokenizer.get_command('[gMASK]'),\n            tokenizer.get_command('sop'),\n        ], [False, False]\n\n        if tools is not None:\n            raise NotImplementedError()\n\n        for message in conv:\n            if message['role'] in ('system', 'user'):\n                loss_mask_val = False\n            else:\n                loss_mask_val = True\n\n            if message['role'] == 'tool':\n                raise NotImplementedError()\n            else:\n                new_input_ids = tokenizer.build_single_message(\n                    message['role'], '', message['content']\n                )\n                new_loss_masks = [loss_mask_val] * len(new_input_ids)\n\n            input_ids += new_input_ids\n            loss_masks += new_loss_masks\n\n        input_ids.append(tokenizer.eos_token_id)\n        loss_masks = [False, *loss_masks]\n        labels = []\n        for input_id, mask in zip(input_ids, loss_masks):\n            if mask:\n                labels.append(input_id)\n            else:\n                labels.append(-100)\n        max_length = max_input_length + max_output_length + 1\n        batched_input_ids.append(input_ids[:max_length])\n        batched_labels.append(labels[:max_length])\n    return {'input_ids': batched_input_ids, 'labels': batched_labels}\n\n\ndef process_batch_eval(\n        batch: Mapping[str, Sequence],\n        tokenizer: PreTrainedTokenizer,\n        max_input_length: int,\n        max_output_length: int,\n) -> dict[str, list]:\n    batched_tools = batch.get('tools', None)\n    batched_conv = batch['conversations']\n    batched_input_ids = []\n    # To avoid computing loss, we do not provide the `labels` field in the input dictionary.\n    batched_output_ids = []\n\n    if batched_tools is None:\n        batched_tools = [None] * len(batched_conv)\n\n    for 
tools, conv in zip(batched_tools, batched_conv):\n        input_ids = [\n            tokenizer.get_command('[gMASK]'),\n            tokenizer.get_command('sop'),\n        ]\n\n        if tools is not None:\n            raise NotImplementedError()\n\n        for message in conv:\n            if len(input_ids) >= max_input_length:\n                break\n            if message['role'] == 'tool':\n                raise NotImplementedError()\n            else:\n                new_input_ids = tokenizer.build_single_message(\n                    message['role'], '', message['content']\n                )\n                if message['role'] == 'assistant':\n                    output_prompt, output_ids = (\n                        new_input_ids[:1],\n                        new_input_ids[1:],\n                    )\n                    output_ids.append(tokenizer.eos_token_id)\n                    batched_input_ids.append(\n                        input_ids[:max_input_length] + output_prompt[:1]\n                    )\n                    batched_output_ids.append(output_ids[:max_output_length])\n                input_ids += new_input_ids\n    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}\n\n\n# Not sure if this is necessary; half precision might suffice.\n# When training on CPU, cast all params to fp32, not only the trainable ones.\ndef _prepare_model_for_training(model: nn.Module, use_cpu: bool):\n    for param in model.parameters():\n        if param.requires_grad or use_cpu:\n            param.data = param.data.to(torch.float32)\n\n\ndef load_tokenizer_and_model(\n        model_dir: str,\n        peft_config: Optional[PeftConfig] = None,\n) -> tuple[PreTrainedTokenizer, nn.Module]:\n    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)\n    if peft_config is not None:\n        if peft_config.peft_type.name == \"PREFIX_TUNING\":\n            config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)\n            
config.pre_seq_len = peft_config.num_virtual_tokens\n            config.use_cache = False\n            model = AutoModelForCausalLM.from_pretrained(\n                model_dir,\n                trust_remote_code=True,\n                config=config,\n            )\n        if peft_config.peft_type.name == \"LORA\":\n            model = AutoModelForCausalLM.from_pretrained(\n                model_dir,\n                trust_remote_code=True,\n                empty_init=False,\n                use_cache=False\n            )\n            model = get_peft_model(model, peft_config)\n            model.print_trainable_parameters()\n    else:\n        model = AutoModelForCausalLM.from_pretrained(\n            model_dir,\n            trust_remote_code=True,\n            empty_init=False,\n            use_cache=False\n        )\n    print_model_size(model)\n    return tokenizer, model\n\n\ndef compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer):\n    batched_pred_ids, batched_label_ids = eval_preds\n\n    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}\n    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):\n        pred_txt = tokenizer.decode(pred_ids).strip()\n        label_txt = tokenizer.decode(label_ids).strip()\n        pred_tokens = list(jieba.cut(pred_txt))\n        label_tokens = list(jieba.cut(label_txt))\n        rouge = Rouge()\n        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))\n        for k, v in scores[0].items():\n            metrics_dct[k].append(round(v['f'] * 100, 4))\n        metrics_dct['bleu-4'].append(\n            sentence_bleu(\n                [label_tokens],\n                pred_tokens,\n                smoothing_function=SmoothingFunction().method3,\n            )\n        )\n    return {k: np.mean(v) for k, v in metrics_dct.items()}\n\n\n@app.command()\ndef main(\n        data_dir: Annotated[str, typer.Argument(help='')],\n        model_dir: 
Annotated[\n            str,\n            typer.Argument(\n                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'\n            ),\n        ],\n        config_file: Annotated[str, typer.Argument(help='')],\n        auto_resume_from_checkpoint: str = typer.Argument(\n            default='',\n            help='If set to yes, automatically resume from the latest saved checkpoint. If a checkpoint number such as 12 or 15 is given, resume from that checkpoint. If left empty, start training from scratch.'\n        ),\n\n):\n    ft_config = FinetuningConfig.from_file(config_file)\n    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)\n    data_manager = DataManager(data_dir, ft_config.data_config)\n\n    train_dataset = data_manager.get_dataset(\n        Split.TRAIN,\n        functools.partial(\n            process_batch,\n            tokenizer=tokenizer,\n            max_input_length=ft_config.max_input_length,\n            max_output_length=ft_config.max_output_length,\n        ),\n        batched=True,\n    )\n    print('train_dataset:', train_dataset)\n    val_dataset = data_manager.get_dataset(\n        Split.VALIDATION,\n        functools.partial(\n            process_batch_eval,\n            tokenizer=tokenizer,\n            max_input_length=ft_config.max_input_length,\n            max_output_length=ft_config.max_output_length,\n        ),\n        batched=True,\n    )\n    if val_dataset is not None:\n        print('val_dataset:', val_dataset)\n    test_dataset = data_manager.get_dataset(\n        Split.TEST,\n        functools.partial(\n            process_batch_eval,\n            tokenizer=tokenizer,\n            max_input_length=ft_config.max_input_length,\n            max_output_length=ft_config.max_output_length,\n        ),\n        batched=True,\n    )\n    if test_dataset is not None:\n        
print('test_dataset:', test_dataset)\n\n    # checks encoded dataset\n    _sanity_check(\n        train_dataset[0][\"input_ids\"], train_dataset[0][\"labels\"], tokenizer\n    )\n\n    # cast trainable params to fp32\n    _prepare_model_for_training(model, ft_config.training_args.use_cpu)\n\n    ft_config.training_args.generation_config.pad_token_id = (\n        tokenizer.pad_token_id\n    )\n    ft_config.training_args.generation_config.eos_token_id = [\n        tokenizer.eos_token_id,\n        tokenizer.get_command('<|user|>'),\n        tokenizer.get_command('<|observation|>'),\n    ]\n    model.gradient_checkpointing_enable()\n    model.enable_input_require_grads()\n\n    use_tokenizer = True\n    if ft_config.peft_config is not None:\n        use_tokenizer = False if ft_config.peft_config.peft_type == \"LORA\" else True\n\n    trainer = Seq2SeqTrainer(\n        model=model,\n        args=ft_config.training_args,\n        data_collator=DataCollatorForSeq2Seq(\n            tokenizer=tokenizer,\n            padding='longest',\n            return_tensors='pt',\n        ),\n        train_dataset=train_dataset,\n        eval_dataset=val_dataset.select(list(range(50))) if val_dataset is not None else None,\n        tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer\n        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),\n    )\n\n    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint == \"\":\n        trainer.train()\n    else:\n        def do_rf_checkpoint(sn):\n            model.gradient_checkpointing_enable()\n            model.enable_input_require_grads()\n            checkpoint_directory = os.path.join(output_dir, \"checkpoint-\" + sn)\n            print(\"resume checkpoint from checkpoint-\" + sn)\n            trainer.train(resume_from_checkpoint=checkpoint_directory)\n\n        output_dir = ft_config.training_args.output_dir\n\n        # resume from latest checkpoint\n        if auto_resume_from_checkpoint.upper() == 
\"YES\":\n            dirlist = os.listdir(output_dir)\n            checkpoint_sn = 0\n            # get latest checkpoint\n            for checkpoint_str in dirlist:\n                if checkpoint_str.find(\"eckpoint\") > 0 and checkpoint_str.find(\"tmp\") == -1:\n                    checkpoint = int(checkpoint_str.replace(\"checkpoint-\", \"\"))\n                    if checkpoint > checkpoint_sn:\n                        checkpoint_sn = checkpoint\n            if checkpoint_sn > 0:\n                do_rf_checkpoint(str(checkpoint_sn))\n            else:\n                trainer.train()\n        else:\n            # resume from specific checkpoint\n            if auto_resume_from_checkpoint.isdigit() and int(auto_resume_from_checkpoint) > 0:\n                do_rf_checkpoint(auto_resume_from_checkpoint)\n            else:\n                print(auto_resume_from_checkpoint,\n                      \"The specified checkpoint sn(\" + auto_resume_from_checkpoint + \") has not been saved. Please search for the correct checkpoint in the model output directory\")\n\n    # test stage\n    if test_dataset is not None:\n        trainer.predict(test_dataset)\n\n\nif __name__ == '__main__':\n    app()\n"
  },
  {
    "path": "finetune_demo/inference_hf.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nfrom pathlib import Path\nfrom typing import Annotated, Union\n\nimport typer\nfrom peft import AutoPeftModelForCausalLM, PeftModelForCausalLM\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    PreTrainedModel,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n)\n\nModelType = Union[PreTrainedModel, PeftModelForCausalLM]\nTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\n\napp = typer.Typer(pretty_exceptions_show_locals=False)\n\n\ndef _resolve_path(path: Union[str, Path]) -> Path:\n    return Path(path).expanduser().resolve()\n\n\ndef load_model_and_tokenizer(model_dir: Union[str, Path]) -> tuple[ModelType, TokenizerType]:\n    model_dir = _resolve_path(model_dir)\n    if (model_dir / 'adapter_config.json').exists():\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=True, device_map='auto'\n        )\n        tokenizer_dir = model.peft_config['default'].base_model_name_or_path\n    else:\n        model = AutoModelForCausalLM.from_pretrained(\n            model_dir, trust_remote_code=True, device_map='auto'\n        )\n        tokenizer_dir = model_dir\n    tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_dir, trust_remote_code=True\n    )\n    return model, tokenizer\n\n\n@app.command()\ndef main(\n        model_dir: Annotated[str, typer.Argument(help='')],\n        prompt: Annotated[str, typer.Option(help='')],\n):\n    model, tokenizer = load_model_and_tokenizer(model_dir)\n    response, _ = model.chat(tokenizer, prompt)\n    print(response)\n\n\nif __name__ == '__main__':\n    app()\n"
  },
  {
    "path": "finetune_demo/lora_finetune.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"# 单卡GPU 进行 ChatGLM3-6B模型 LoRA 高效微调\\n\",\n    \"本 Cookbook 将带领开发者使用 `AdvertiseGen` 数据集对 ChatGLM3-6B 进行 LoRA 微调，使其具备专业的广告生成能力。\\n\",\n    \"\\n\",\n    \"## 硬件需求\\n\",\n    \"显存：24GB及以上（推荐使用30系或A10等sm80架构以上的NVIDIA显卡进行尝试）\\n\",\n    \"内存：16GB\\n\",\n    \"RAM: 2.9 /16 GB\\n\",\n    \"GPU RAM: 15.5/16.0 GB\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"id\": \"89b89f64d8f8053d\"\n   },\n   \"id\": \"89b89f64d8f8053d\"\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"## 0. 环境检查\\n\",\n    \"首先，先检查代码的运行地址，确保运行地址处于 `finetune_demo` 中。\\n\",\n    \"并且，确保已经安装了 `requirements.txt`中的依赖。\\n\",\n    \"\\n\",\n    \"> 本 demo 中，不需要使用 deepspeed, mpi4py 两个依赖，如果您安装这两个依赖遇到问题，可以不安装这两个依赖。\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"id\": \"a7bd9a514ed09ea6\"\n   },\n   \"id\": \"a7bd9a514ed09ea6\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/media/zr/Data/Code/ChatGLM3/finetune_demo\\r\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!pwd\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-14T05:29:22.200365Z\",\n     \"start_time\": \"2024-04-14T05:29:22.080929Z\"\n    }\n   },\n   \"id\": \"f7703109d1443346\",\n   \"execution_count\": 1\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"## 1. 
准备数据集\\n\",\n    \"我们使用 AdvertiseGen 数据集来进行微调。从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 AdvertiseGen 数据集，将解压后的 AdvertiseGen 目录放到本目录的 `/data/` 下, 例如。\\n\",\n    \"> /media/zr/Data/Code/ChatGLM3/finetune_demo/data/AdvertiseGen\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"id\": \"2f50e92810011977\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"from typing import Union\\n\",\n    \"from pathlib import Path\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def _resolve_path(path: Union[str, Path]) -> Path:\\n\",\n    \"    return Path(path).expanduser().resolve()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def _mkdir(dir_name: Union[str, Path]):\\n\",\n    \"    dir_name = _resolve_path(dir_name)\\n\",\n    \"    if not dir_name.is_dir():\\n\",\n    \"        dir_name.mkdir(parents=True, exist_ok=False)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):\\n\",\n    \"    def _convert(in_file: Path, out_file: Path):\\n\",\n    \"        _mkdir(out_file.parent)\\n\",\n    \"        with open(in_file, encoding='utf-8') as fin:\\n\",\n    \"            with open(out_file, 'wt', encoding='utf-8') as fout:\\n\",\n    \"                for line in fin:\\n\",\n    \"                    dct = json.loads(line)\\n\",\n    \"                    sample = {'conversations': [{'role': 'user', 'content': dct['content']},\\n\",\n    \"                                                {'role': 'assistant', 'content': dct['summary']}]}\\n\",\n    \"                    fout.write(json.dumps(sample, ensure_ascii=False) + '\\\\n')\\n\",\n    \"\\n\",\n    \"    data_dir = _resolve_path(data_dir)\\n\",\n    \"    save_dir = _resolve_path(save_dir)\\n\",\n    \"\\n\",\n    \"    train_file = data_dir / 
'train.json'\\n\",\n    \"    if train_file.is_file():\\n\",\n    \"        out_file = save_dir / train_file.relative_to(data_dir)\\n\",\n    \"        _convert(train_file, out_file)\\n\",\n    \"\\n\",\n    \"    dev_file = data_dir / 'dev.json'\\n\",\n    \"    if dev_file.is_file():\\n\",\n    \"        out_file = save_dir / dev_file.relative_to(data_dir)\\n\",\n    \"        _convert(dev_file, out_file)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')\"\n   ],\n   \"metadata\": {\n    \"collapsed\": true,\n    \"cellView\": \"form\",\n    \"id\": \"initial_id\",\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-14T05:29:23.809255Z\",\n     \"start_time\": \"2024-04-14T05:29:22.202731Z\"\n    }\n   },\n   \"id\": \"initial_id\",\n   \"execution_count\": 2\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"## 2. 使用命令行开始微调,我们使用 lora 进行微调\\n\",\n    \"接着，我们仅需要将配置好的参数以命令行的形式传参给程序，就可以使用命令行进行高效微调。\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"id\": \"a1b7a99923349056\"\n   },\n   \"id\": \"a1b7a99923349056\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Setting eos_token is not supported, use the default one.\\r\\n\",\n      \"Setting pad_token is not supported, use the default one.\\r\\n\",\n      \"Setting unk_token is not supported, use the default one.\\r\\n\",\n      \"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00,  2.77it/s]\\r\\n\",\n      \"trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614\\r\\n\",\n      \"--> Model\\r\\n\",\n      \"\\r\\n\",\n      \"--> model has 1.949696M params\\r\\n\",\n      \"\\r\\n\",\n      \"Setting num_proc from 16 back to 1 for the train split to disable multiprocessing as it only contains one shard.\\r\\n\",\n      \"Generating train 
split: 114599 examples [00:00, 836881.77 examples/s]\\r\\n\",\n      \"Setting num_proc from 16 back to 1 for the validation split to disable multiprocessing as it only contains one shard.\\r\\n\",\n      \"Generating validation split: 1070 examples [00:00, 252512.53 examples/s]\\r\\n\",\n      \"Setting num_proc from 16 back to 1 for the test split to disable multiprocessing as it only contains one shard.\\r\\n\",\n      \"Generating test split: 1070 examples [00:00, 313510.67 examples/s]\\r\\n\",\n      \"Map (num_proc=16): 100%|██████| 114599/114599 [00:02<00:00, 39254.76 examples/s]\\r\\n\",\n      \"train_dataset: Dataset({\\r\\n\",\n      \"    features: ['input_ids', 'labels'],\\r\\n\",\n      \"    num_rows: 114599\\r\\n\",\n      \"})\\r\\n\",\n      \"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1399.56 examples/s]\\r\\n\",\n      \"val_dataset: Dataset({\\r\\n\",\n      \"    features: ['input_ids', 'output_ids'],\\r\\n\",\n      \"    num_rows: 1070\\r\\n\",\n      \"})\\r\\n\",\n      \"Map (num_proc=16): 100%|███████████| 1070/1070 [00:00<00:00, 1339.19 examples/s]\\r\\n\",\n      \"test_dataset: Dataset({\\r\\n\",\n      \"    features: ['input_ids', 'output_ids'],\\r\\n\",\n      \"    num_rows: 1070\\r\\n\",\n      \"})\\r\\n\",\n      \"--> Sanity check\\r\\n\",\n      \"           '[gMASK]': 64790 -> -100\\r\\n\",\n      \"               'sop': 64792 -> -100\\r\\n\",\n      \"          '<|user|>': 64795 -> -100\\r\\n\",\n      \"                  '': 30910 -> -100\\r\\n\",\n      \"                '\\\\n': 13 -> -100\\r\\n\",\n      \"                  '': 30910 -> -100\\r\\n\",\n      \"                '类型': 33467 -> -100\\r\\n\",\n      \"                 '#': 31010 -> -100\\r\\n\",\n      \"                 '裤': 56532 -> -100\\r\\n\",\n      \"                 '*': 30998 -> -100\\r\\n\",\n      \"                 '版': 55090 -> -100\\r\\n\",\n      \"                 '型': 54888 -> -100\\r\\n\",\n      \"                 '#': 
31010 -> -100\\r\\n\",\n      \"                '宽松': 40833 -> -100\\r\\n\",\n      \"                 '*': 30998 -> -100\\r\\n\",\n      \"                '风格': 32799 -> -100\\r\\n\",\n      \"                 '#': 31010 -> -100\\r\\n\",\n      \"                '性感': 40589 -> -100\\r\\n\",\n      \"                 '*': 30998 -> -100\\r\\n\",\n      \"                '图案': 37505 -> -100\\r\\n\",\n      \"                 '#': 31010 -> -100\\r\\n\",\n      \"                '线条': 37216 -> -100\\r\\n\",\n      \"                 '*': 30998 -> -100\\r\\n\",\n      \"                 '裤': 56532 -> -100\\r\\n\",\n      \"                 '型': 54888 -> -100\\r\\n\",\n      \"                 '#': 31010 -> -100\\r\\n\",\n      \"                 '阔': 56529 -> -100\\r\\n\",\n      \"                 '腿': 56158 -> -100\\r\\n\",\n      \"                 '裤': 56532 -> -100\\r\\n\",\n      \"     '<|assistant|>': 64796 -> -100\\r\\n\",\n      \"                  '': 30910 -> 30910\\r\\n\",\n      \"                '\\\\n': 13 -> 13\\r\\n\",\n      \"                  '': 30910 -> 30910\\r\\n\",\n      \"                '宽松': 40833 -> 40833\\r\\n\",\n      \"                 '的': 54530 -> 54530\\r\\n\",\n      \"                 '阔': 56529 -> 56529\\r\\n\",\n      \"                 '腿': 56158 -> 56158\\r\\n\",\n      \"                 '裤': 56532 -> 56532\\r\\n\",\n      \"                 '这': 54551 -> 54551\\r\\n\",\n      \"                '两年': 33808 -> 33808\\r\\n\",\n      \"                '真的': 32041 -> 32041\\r\\n\",\n      \"                 '吸': 55360 -> 55360\\r\\n\",\n      \"                 '粉': 55486 -> 55486\\r\\n\",\n      \"                '不少': 32138 -> 32138\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"                '明星': 32943 -> 32943\\r\\n\",\n      \"                '时尚': 33481 -> 33481\\r\\n\",\n      \"                 '达': 54880 -> 54880\\r\\n\",\n      \"                '人的': 31664 -> 31664\\r\\n\",\n      \"        
        '心头': 46565 -> 46565\\r\\n\",\n      \"                 '爱': 54799 -> 54799\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                '毕竟': 33051 -> 33051\\r\\n\",\n      \"                 '好': 54591 -> 54591\\r\\n\",\n      \"                 '穿': 55432 -> 55432\\r\\n\",\n      \"                '时尚': 33481 -> 33481\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"                 '谁': 55622 -> 55622\\r\\n\",\n      \"                '都能': 32904 -> 32904\\r\\n\",\n      \"                 '穿': 55432 -> 55432\\r\\n\",\n      \"                 '出': 54557 -> 54557\\r\\n\",\n      \"                 '腿': 56158 -> 56158\\r\\n\",\n      \"                 '长': 54625 -> 54625\\r\\n\",\n      \"                 '2': 30943 -> 30943\\r\\n\",\n      \"                 '米': 55055 -> 55055\\r\\n\",\n      \"               '的效果': 35590 -> 35590\\r\\n\",\n      \"                '宽松': 40833 -> 40833\\r\\n\",\n      \"                 '的': 54530 -> 54530\\r\\n\",\n      \"                 '裤': 56532 -> 56532\\r\\n\",\n      \"                 '腿': 56158 -> 56158\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"               '当然是': 48466 -> 48466\\r\\n\",\n      \"                 '遮': 57148 -> 57148\\r\\n\",\n      \"                 '肉': 55343 -> 55343\\r\\n\",\n      \"                 '小': 54603 -> 54603\\r\\n\",\n      \"                '能手': 49355 -> 49355\\r\\n\",\n      \"                 '啊': 55674 -> 55674\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                '上身': 51605 -> 51605\\r\\n\",\n      \"                 '随': 55119 -> 55119\\r\\n\",\n      \"                 '性': 54642 -> 54642\\r\\n\",\n      \"                '自然': 31799 -> 31799\\r\\n\",\n      \"                 '不': 54535 -> 54535\\r\\n\",\n      \"                 '拘': 57036 -> 57036\\r\\n\",\n      \"                 '束': 55625 -> 55625\\r\\n\",\n      \"                 '，': 31123 
-> 31123\\r\\n\",\n      \"                '面料': 46839 -> 46839\\r\\n\",\n      \"                 '亲': 55113 -> 55113\\r\\n\",\n      \"                 '肤': 56089 -> 56089\\r\\n\",\n      \"                '舒适': 33894 -> 33894\\r\\n\",\n      \"                 '贴': 55778 -> 55778\\r\\n\",\n      \"                '身体': 31902 -> 31902\\r\\n\",\n      \"                 '验': 55017 -> 55017\\r\\n\",\n      \"                 '感': 54706 -> 54706\\r\\n\",\n      \"                 '棒': 56382 -> 56382\\r\\n\",\n      \"                 '棒': 56382 -> 56382\\r\\n\",\n      \"                 '哒': 59230 -> 59230\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                 '系': 54712 -> 54712\\r\\n\",\n      \"                 '带': 54882 -> 54882\\r\\n\",\n      \"                '部分': 31726 -> 31726\\r\\n\",\n      \"                '增加': 31917 -> 31917\\r\\n\",\n      \"                '设计': 31735 -> 31735\\r\\n\",\n      \"                '看点': 45032 -> 45032\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"                 '还': 54656 -> 54656\\r\\n\",\n      \"                 '让': 54772 -> 54772\\r\\n\",\n      \"                '单品': 46539 -> 46539\\r\\n\",\n      \"               '的设计': 34481 -> 34481\\r\\n\",\n      \"                 '感': 54706 -> 54706\\r\\n\",\n      \"                '更强': 43084 -> 43084\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                '腿部': 46799 -> 46799\\r\\n\",\n      \"                '线条': 37216 -> 37216\\r\\n\",\n      \"                 '若': 55351 -> 55351\\r\\n\",\n      \"                 '隐': 55733 -> 55733\\r\\n\",\n      \"                 '若': 55351 -> 55351\\r\\n\",\n      \"                 '现': 54600 -> 54600\\r\\n\",\n      \"                 '的': 54530 -> 54530\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"                '性感': 40589 -> 40589\\r\\n\",\n      \"                 '撩': 58521 -> 58521\\r\\n\",\n 
     \"                 '人': 54533 -> 54533\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                '颜色': 33692 -> 33692\\r\\n\",\n      \"                 '敲': 57004 -> 57004\\r\\n\",\n      \"                '温柔': 34678 -> 34678\\r\\n\",\n      \"                 '的': 54530 -> 54530\\r\\n\",\n      \"                 '，': 31123 -> 31123\\r\\n\",\n      \"                 '与': 54619 -> 54619\\r\\n\",\n      \"                '裤子': 44722 -> 44722\\r\\n\",\n      \"                '本身': 32754 -> 32754\\r\\n\",\n      \"                 '所': 54626 -> 54626\\r\\n\",\n      \"                '呈现': 33169 -> 33169\\r\\n\",\n      \"               '的风格': 48084 -> 48084\\r\\n\",\n      \"                '有点': 33149 -> 33149\\r\\n\",\n      \"                 '反': 54955 -> 54955\\r\\n\",\n      \"                 '差': 55342 -> 55342\\r\\n\",\n      \"                 '萌': 56842 -> 56842\\r\\n\",\n      \"                 '。': 31155 -> 31155\\r\\n\",\n      \"                  '': 2 -> 2\\r\\n\",\n      \"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \\r\\n\",\n      \"dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\\r\\n\",\n      \"  warnings.warn(\\r\\n\",\n      \"max_steps is given, it will override any value given in num_train_epochs\\r\\n\",\n      \"***** Running training *****\\r\\n\",\n      \"  Num examples = 114,599\\r\\n\",\n      \"  Num Epochs = 1\\r\\n\",\n      \"  Instantaneous batch size per device = 4\\r\\n\",\n      \"  Total train batch size (w. 
parallel, distributed & accumulation) = 4\\r\\n\",\n      \"  Gradient Accumulation steps = 1\\r\\n\",\n      \"  Total optimization steps = 4,000\\r\\n\",\n      \"  Number of trainable parameters = 1,949,696\\r\\n\",\n      \"{'loss': 4.832, 'grad_norm': 2.1177706718444824, 'learning_rate': 4.9875000000000006e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 4.6094, 'grad_norm': 3.104412078857422, 'learning_rate': 4.975e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 4.5043, 'grad_norm': 2.9755077362060547, 'learning_rate': 4.962500000000001e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 4.14, 'grad_norm': 3.3869752883911133, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 4.1275, 'grad_norm': 2.698483467102051, 'learning_rate': 4.937500000000001e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.8748, 'grad_norm': 2.9052674770355225, 'learning_rate': 4.9250000000000004e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.8506, 'grad_norm': 2.8566994667053223, 'learning_rate': 4.9125e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.7518, 'grad_norm': 2.9119534492492676, 'learning_rate': 4.9e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.6375, 'grad_norm': 3.1845204830169678, 'learning_rate': 4.8875e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.7219, 'grad_norm': 3.359720230102539, 'learning_rate': 4.875e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.676, 'grad_norm': 3.559992790222168, 'learning_rate': 4.8625e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.849, 'grad_norm': 3.822449207305908, 'learning_rate': 4.85e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.6154, 'grad_norm': 3.4438886642456055, 'learning_rate': 4.8375000000000004e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.7326, 'grad_norm': 4.374788284301758, 'learning_rate': 4.825e-05, 'epoch': 0.0}\\r\\n\",\n      \"{'loss': 3.6854, 'grad_norm': 3.5999808311462402, 'learning_rate': 4.8125000000000004e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.7447, 'grad_norm': 
3.8460822105407715, 'learning_rate': 4.8e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5766, 'grad_norm': 4.053386211395264, 'learning_rate': 4.7875000000000005e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5758, 'grad_norm': 4.296564102172852, 'learning_rate': 4.775e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5486, 'grad_norm': 4.701301574707031, 'learning_rate': 4.7625000000000006e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5775, 'grad_norm': 4.4896979331970215, 'learning_rate': 4.75e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.55, 'grad_norm': 4.9407429695129395, 'learning_rate': 4.7375e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6437, 'grad_norm': 4.0624542236328125, 'learning_rate': 4.7249999999999997e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6098, 'grad_norm': 4.786097049713135, 'learning_rate': 4.7125e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5107, 'grad_norm': 4.457597255706787, 'learning_rate': 4.7e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4723, 'grad_norm': 5.279415130615234, 'learning_rate': 4.6875e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6016, 'grad_norm': 5.297557353973389, 'learning_rate': 4.6750000000000005e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5475, 'grad_norm': 5.397997856140137, 'learning_rate': 4.6625e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6115, 'grad_norm': 4.472784519195557, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6273, 'grad_norm': 4.7433905601501465, 'learning_rate': 4.6375e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5379, 'grad_norm': 5.81007194519043, 'learning_rate': 4.6250000000000006e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4654, 'grad_norm': 5.297420501708984, 'learning_rate': 4.6125e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6057, 'grad_norm': 5.738197326660156, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4168, 'grad_norm': 5.207597732543945, 'learning_rate': 
4.5875000000000004e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4932, 'grad_norm': 5.2784833908081055, 'learning_rate': 4.575e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.518, 'grad_norm': 5.428376197814941, 'learning_rate': 4.5625e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5727, 'grad_norm': 5.190096855163574, 'learning_rate': 4.55e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.3615, 'grad_norm': 4.818575859069824, 'learning_rate': 4.5375e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5275, 'grad_norm': 5.174643039703369, 'learning_rate': 4.525e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.5232, 'grad_norm': 5.241923809051514, 'learning_rate': 4.5125e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4699, 'grad_norm': 5.603521823883057, 'learning_rate': 4.5e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6916, 'grad_norm': 5.468681335449219, 'learning_rate': 4.4875e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.4975, 'grad_norm': 4.969369888305664, 'learning_rate': 4.4750000000000004e-05, 'epoch': 0.01}\\r\\n\",\n      \"{'loss': 3.6207, 'grad_norm': 5.575362682342529, 'learning_rate': 4.4625e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4152, 'grad_norm': 6.52517032623291, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4098, 'grad_norm': 5.987551212310791, 'learning_rate': 4.4375e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4244, 'grad_norm': 5.613704681396484, 'learning_rate': 4.4250000000000005e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5303, 'grad_norm': 5.790269374847412, 'learning_rate': 4.4125e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4475, 'grad_norm': 7.037369728088379, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4562, 'grad_norm': 5.771510601043701, 'learning_rate': 4.3875e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5623, 'grad_norm': 5.876147747039795, 'learning_rate': 4.375e-05, 'epoch': 0.02}\\r\\n\",\n      \" 12%|█████       
                            | 500/4000 [04:39<37:01,  1.58it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:16<00:16,  8.09s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:32<00:11, 11.45s/it]\\u001B[A\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:49<00:00, 13.52s/it]\\u001B[ABuilding prefix dict from the default dictionary ...\\r\\n\",\n      \"Dumping model to file cache /tmp/jieba.cache\\r\\n\",\n      \"Loading model cost 0.580 seconds.\\r\\n\",\n      \"Prefix dict has been built successfully.\\r\\n\",\n      \"                                                                                \\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 31.645344, 'eval_rouge-2': 6.79404, 'eval_rouge-l': 23.83732, 'eval_bleu-4': 0.03250689604242964, 'eval_runtime': 54.3911, 'eval_samples_per_second': 0.919, 'eval_steps_per_second': 0.074, 'epoch': 0.02}\\r\\n\",\n      \" 12%|█████                                   | 500/4000 [05:34<37:01,  1.58it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:50<00:00, 13.52s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.3207, 'grad_norm': 5.6840596199035645, 'learning_rate': 4.3625e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5459, 'grad_norm': 6.672524929046631, 'learning_rate': 4.35e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5822, 'grad_norm': 5.989180564880371, 'learning_rate': 4.3375000000000004e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4859, 'grad_norm': 5.341927528381348, 'learning_rate': 4.325e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5219, 'grad_norm': 5.3769707679748535, 'learning_rate': 4.3125000000000005e-05, 'epoch': 0.02}\\r\\n\",\n      
\"{'loss': 3.6453, 'grad_norm': 5.812618732452393, 'learning_rate': 4.3e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4934, 'grad_norm': 5.726740837097168, 'learning_rate': 4.2875000000000005e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.3719, 'grad_norm': 5.551002025604248, 'learning_rate': 4.275e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4236, 'grad_norm': 6.213701248168945, 'learning_rate': 4.2625000000000006e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4887, 'grad_norm': 6.39825963973999, 'learning_rate': 4.25e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4365, 'grad_norm': 6.213500499725342, 'learning_rate': 4.237500000000001e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4559, 'grad_norm': 6.593310356140137, 'learning_rate': 4.2250000000000004e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4463, 'grad_norm': 5.9485673904418945, 'learning_rate': 4.2125e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4531, 'grad_norm': 6.2323737144470215, 'learning_rate': 4.2e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5338, 'grad_norm': 5.925570964813232, 'learning_rate': 4.1875e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4822, 'grad_norm': 6.287123203277588, 'learning_rate': 4.175e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5402, 'grad_norm': 6.1548848152160645, 'learning_rate': 4.1625e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.3025, 'grad_norm': 6.961801052093506, 'learning_rate': 4.15e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4016, 'grad_norm': 6.60474967956543, 'learning_rate': 4.1375e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.3547, 'grad_norm': 6.296048641204834, 'learning_rate': 4.125e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.4992, 'grad_norm': 7.013551712036133, 'learning_rate': 4.1125000000000004e-05, 'epoch': 0.02}\\r\\n\",\n      \"{'loss': 3.5275, 'grad_norm': 6.747519493103027, 'learning_rate': 4.1e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.2475, 'grad_norm': 6.900665283203125, 'learning_rate': 
4.0875000000000004e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5727, 'grad_norm': 5.7873334884643555, 'learning_rate': 4.075e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3961, 'grad_norm': 6.46198844909668, 'learning_rate': 4.0625000000000005e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4777, 'grad_norm': 6.117852687835693, 'learning_rate': 4.05e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.6215, 'grad_norm': 6.421164035797119, 'learning_rate': 4.0375e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4736, 'grad_norm': 6.280588626861572, 'learning_rate': 4.025e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3248, 'grad_norm': 6.418524265289307, 'learning_rate': 4.0125e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5496, 'grad_norm': 6.983282089233398, 'learning_rate': 4e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.2926, 'grad_norm': 6.696746349334717, 'learning_rate': 3.9875e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3609, 'grad_norm': 6.474392414093018, 'learning_rate': 3.9750000000000004e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.458, 'grad_norm': 7.111743450164795, 'learning_rate': 3.9625e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4062, 'grad_norm': 6.317008018493652, 'learning_rate': 3.9500000000000005e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5057, 'grad_norm': 6.232912540435791, 'learning_rate': 3.9375e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5305, 'grad_norm': 6.192782402038574, 'learning_rate': 3.9250000000000005e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.2908, 'grad_norm': 7.155930042266846, 'learning_rate': 3.9125e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4904, 'grad_norm': 6.664801597595215, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4529, 'grad_norm': 7.4175615310668945, 'learning_rate': 3.8875e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.2643, 'grad_norm': 7.862004280090332, 'learning_rate': 3.875e-05, 'epoch': 0.03}\\r\\n\",\n      
\"{'loss': 3.4562, 'grad_norm': 7.8772687911987305, 'learning_rate': 3.8625e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4186, 'grad_norm': 6.901059150695801, 'learning_rate': 3.85e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4582, 'grad_norm': 7.472389221191406, 'learning_rate': 3.8375e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5643, 'grad_norm': 7.333090305328369, 'learning_rate': 3.825e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3639, 'grad_norm': 6.445948600769043, 'learning_rate': 3.8125e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4389, 'grad_norm': 7.957160949707031, 'learning_rate': 3.8e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.5336, 'grad_norm': 5.9428324699401855, 'learning_rate': 3.7875e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3242, 'grad_norm': 6.897878646850586, 'learning_rate': 3.775e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.4594, 'grad_norm': 7.274386882781982, 'learning_rate': 3.7625e-05, 'epoch': 0.03}\\r\\n\",\n      \"{'loss': 3.3949, 'grad_norm': 7.8012471199035645, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.03}\\r\\n\",\n      \" 25%|█████████▊                             | 1000/4000 [10:11<28:52,  1.73it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:03<00:03,  1.53s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:05<00:01,  1.97s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 32.134831999999996, 'eval_rouge-2': 6.325576000000001, 'eval_rouge-l': 25.315346000000005, 'eval_bleu-4': 0.03137707571044217, 'eval_runtime': 9.9272, 'eval_samples_per_second': 5.037, 
'eval_steps_per_second': 0.403, 'epoch': 0.03}\\r\\n\",\n      \" 25%|█████████▊                             | 1000/4000 [10:21<28:52,  1.73it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.77s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.4504, 'grad_norm': 6.908702373504639, 'learning_rate': 3.737500000000001e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4596, 'grad_norm': 7.377086639404297, 'learning_rate': 3.7250000000000004e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.6484, 'grad_norm': 8.061379432678223, 'learning_rate': 3.7125e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4, 'grad_norm': 6.452291011810303, 'learning_rate': 3.7e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3891, 'grad_norm': 8.560649871826172, 'learning_rate': 3.6875e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3551, 'grad_norm': 7.644310474395752, 'learning_rate': 3.675e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3895, 'grad_norm': 7.036133766174316, 'learning_rate': 3.6625e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4611, 'grad_norm': 7.2408528327941895, 'learning_rate': 3.65e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.5271, 'grad_norm': 7.058151721954346, 'learning_rate': 3.6375e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4662, 'grad_norm': 6.564244747161865, 'learning_rate': 3.625e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3428, 'grad_norm': 6.844818115234375, 'learning_rate': 3.6125000000000004e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.5244, 'grad_norm': 7.949232578277588, 'learning_rate': 3.6e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4357, 'grad_norm': 7.32559871673584, 'learning_rate': 3.5875000000000005e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3572, 'grad_norm': 8.051689147949219, 'learning_rate': 3.575e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3174, 'grad_norm': 7.550294399261475, 'learning_rate': 3.5625000000000005e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3588, 
'grad_norm': 7.240135669708252, 'learning_rate': 3.55e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4516, 'grad_norm': 6.720525741577148, 'learning_rate': 3.5375e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4717, 'grad_norm': 6.3586320877075195, 'learning_rate': 3.525e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3574, 'grad_norm': 6.693387985229492, 'learning_rate': 3.5125e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.407, 'grad_norm': 6.322566509246826, 'learning_rate': 3.5e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.2439, 'grad_norm': 6.481217384338379, 'learning_rate': 3.4875e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3391, 'grad_norm': 7.359728813171387, 'learning_rate': 3.475e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3771, 'grad_norm': 7.4071478843688965, 'learning_rate': 3.4625e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3758, 'grad_norm': 7.325416564941406, 'learning_rate': 3.45e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4434, 'grad_norm': 6.780652046203613, 'learning_rate': 3.4375e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.2818, 'grad_norm': 7.619284152984619, 'learning_rate': 3.4250000000000006e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.4562, 'grad_norm': 7.123080253601074, 'learning_rate': 3.4125e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3322, 'grad_norm': 7.0780863761901855, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.04}\\r\\n\",\n      \"{'loss': 3.3887, 'grad_norm': 6.898688316345215, 'learning_rate': 3.3875000000000003e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4793, 'grad_norm': 7.293100357055664, 'learning_rate': 3.375000000000001e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4607, 'grad_norm': 6.927903175354004, 'learning_rate': 3.3625000000000004e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4535, 'grad_norm': 6.639427661895752, 'learning_rate': 3.35e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4008, 'grad_norm': 10.613078117370605, 'learning_rate': 3.3375e-05, 
'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3059, 'grad_norm': 7.491557598114014, 'learning_rate': 3.325e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3484, 'grad_norm': 7.497087001800537, 'learning_rate': 3.3125e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.2969, 'grad_norm': 8.017332077026367, 'learning_rate': 3.3e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.5152, 'grad_norm': 7.311262130737305, 'learning_rate': 3.2875e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3871, 'grad_norm': 7.2260003089904785, 'learning_rate': 3.275e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3563, 'grad_norm': 7.222864151000977, 'learning_rate': 3.2625e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4166, 'grad_norm': 6.612077713012695, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3465, 'grad_norm': 7.431714057922363, 'learning_rate': 3.2375e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.2621, 'grad_norm': 7.619777202606201, 'learning_rate': 3.2250000000000005e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3795, 'grad_norm': 7.628826141357422, 'learning_rate': 3.2125e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3551, 'grad_norm': 7.093392848968506, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.2658, 'grad_norm': 6.70922327041626, 'learning_rate': 3.1875e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3914, 'grad_norm': 7.325173377990723, 'learning_rate': 3.175e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4367, 'grad_norm': 9.542543411254883, 'learning_rate': 3.1624999999999996e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.2979, 'grad_norm': 6.646926403045654, 'learning_rate': 3.15e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4375, 'grad_norm': 7.366168975830078, 'learning_rate': 3.1375e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4574, 'grad_norm': 6.800962924957275, 'learning_rate': 3.125e-05, 'epoch': 0.05}\\r\\n\",\n      \" 38%|██████████████▋                     
   | 1500/4000 [14:57<20:28,  2.03it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:02<00:02,  1.43s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:18<00:07,  7.54s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 31.905676000000007, 'eval_rouge-2': 6.630377999999999, 'eval_rouge-l': 25.126853999999998, 'eval_bleu-4': 0.03152151596531457, 'eval_runtime': 23.6793, 'eval_samples_per_second': 2.112, 'eval_steps_per_second': 0.169, 'epoch': 0.05}\\r\\n\",\n      \" 38%|██████████████▋                        | 1500/4000 [15:21<20:28,  2.03it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:20<00:00,  5.41s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.3451, 'grad_norm': 6.90294075012207, 'learning_rate': 3.1125000000000004e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3844, 'grad_norm': 8.37482738494873, 'learning_rate': 3.1e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4359, 'grad_norm': 8.105109214782715, 'learning_rate': 3.0875000000000005e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.3988, 'grad_norm': 7.031566143035889, 'learning_rate': 3.075e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4945, 'grad_norm': 7.260471343994141, 'learning_rate': 3.0625000000000006e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4061, 'grad_norm': 8.252367973327637, 'learning_rate': 3.05e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4643, 'grad_norm': 7.982962131500244, 'learning_rate': 3.0375000000000003e-05, 'epoch': 0.05}\\r\\n\",\n      \"{'loss': 3.4326, 'grad_norm': 7.5859808921813965, 'learning_rate': 3.025e-05, 
'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.5098, 'grad_norm': 9.218013763427734, 'learning_rate': 3.0125000000000004e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3924, 'grad_norm': 7.129590034484863, 'learning_rate': 3e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3645, 'grad_norm': 7.882465362548828, 'learning_rate': 2.9875000000000004e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3656, 'grad_norm': 8.374431610107422, 'learning_rate': 2.975e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.4676, 'grad_norm': 7.145497798919678, 'learning_rate': 2.9625000000000002e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3199, 'grad_norm': 7.946256160736084, 'learning_rate': 2.95e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3682, 'grad_norm': 7.46930456161499, 'learning_rate': 2.9375000000000003e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.2996, 'grad_norm': 6.9753265380859375, 'learning_rate': 2.925e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.475, 'grad_norm': 8.484821319580078, 'learning_rate': 2.9125000000000003e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3715, 'grad_norm': 7.118030548095703, 'learning_rate': 2.9e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3742, 'grad_norm': 7.3347368240356445, 'learning_rate': 2.8875e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.5146, 'grad_norm': 6.8588714599609375, 'learning_rate': 2.8749999999999997e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.4602, 'grad_norm': 7.292227745056152, 'learning_rate': 2.8625e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.499, 'grad_norm': 7.423632621765137, 'learning_rate': 2.8499999999999998e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.4059, 'grad_norm': 7.430981636047363, 'learning_rate': 2.8375000000000002e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.398, 'grad_norm': 7.364171981811523, 'learning_rate': 2.825e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.4631, 'grad_norm': 7.548583984375, 'learning_rate': 2.8125000000000003e-05, 'epoch': 
0.06}\\r\\n\",\n      \"{'loss': 3.442, 'grad_norm': 7.765754699707031, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3605, 'grad_norm': 8.27833366394043, 'learning_rate': 2.7875e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3459, 'grad_norm': 8.09084415435791, 'learning_rate': 2.7750000000000004e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3928, 'grad_norm': 8.150015830993652, 'learning_rate': 2.7625e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3408, 'grad_norm': 7.760500907897949, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3803, 'grad_norm': 8.982950210571289, 'learning_rate': 2.7375e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3381, 'grad_norm': 7.609743118286133, 'learning_rate': 2.725e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.5785, 'grad_norm': 7.900216102600098, 'learning_rate': 2.7125000000000002e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3395, 'grad_norm': 8.472111701965332, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.4895, 'grad_norm': 8.781264305114746, 'learning_rate': 2.6875e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3846, 'grad_norm': 7.472824573516846, 'learning_rate': 2.6750000000000003e-05, 'epoch': 0.06}\\r\\n\",\n      \"{'loss': 3.3115, 'grad_norm': 8.073516845703125, 'learning_rate': 2.6625e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3037, 'grad_norm': 7.2763519287109375, 'learning_rate': 2.6500000000000004e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3965, 'grad_norm': 7.201462268829346, 'learning_rate': 2.6375e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3717, 'grad_norm': 7.831448554992676, 'learning_rate': 2.625e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.391, 'grad_norm': 7.940402507781982, 'learning_rate': 2.6124999999999998e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.477, 'grad_norm': 7.303577899932861, 'learning_rate': 2.6000000000000002e-05, 'epoch': 
0.07}\\r\\n\",\n      \"{'loss': 3.2766, 'grad_norm': 7.596188545227051, 'learning_rate': 2.5875e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4998, 'grad_norm': 7.545307159423828, 'learning_rate': 2.5750000000000002e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3592, 'grad_norm': 6.786509990692139, 'learning_rate': 2.5625e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.2854, 'grad_norm': 8.573935508728027, 'learning_rate': 2.5500000000000003e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3727, 'grad_norm': 7.578614234924316, 'learning_rate': 2.5375e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.2307, 'grad_norm': 7.565990447998047, 'learning_rate': 2.525e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.41, 'grad_norm': 7.094372749328613, 'learning_rate': 2.5124999999999997e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4619, 'grad_norm': 7.98245096206665, 'learning_rate': 2.5e-05, 'epoch': 0.07}\\r\\n\",\n      \" 50%|███████████████████▌                   | 2000/4000 [19:57<17:54,  1.86it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:16<00:16,  8.01s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:32<00:11, 11.33s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 31.442076, 'eval_rouge-2': 7.156823999999999, 'eval_rouge-l': 23.246924000000003, 'eval_bleu-4': 0.03405216374744, 'eval_runtime': 64.2793, 'eval_samples_per_second': 0.778, 'eval_steps_per_second': 0.062, 'epoch': 0.07}\\r\\n\",\n      \" 50%|███████████████████▌                   | 2000/4000 [21:01<17:54,  1.86it/s]\\r\\n\",\n      
\"100%|█████████████████████████████████████████████| 4/4 [00:48<00:00, 12.97s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[ASaving model checkpoint to ./output/checkpoint-2000\\r\\n\",\n      \"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\\r\\n\",\n      \"  warnings.warn(\\r\\n\",\n      \"{'loss': 3.3818, 'grad_norm': 8.677833557128906, 'learning_rate': 2.4875e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4928, 'grad_norm': 7.391153812408447, 'learning_rate': 2.4750000000000002e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.5547, 'grad_norm': 8.77245044708252, 'learning_rate': 2.4625000000000002e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4939, 'grad_norm': 8.10531997680664, 'learning_rate': 2.45e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3687, 'grad_norm': 8.14376449584961, 'learning_rate': 2.4375e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3307, 'grad_norm': 7.644017219543457, 'learning_rate': 2.425e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4414, 'grad_norm': 7.982100486755371, 'learning_rate': 2.4125e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4115, 'grad_norm': 8.171486854553223, 'learning_rate': 2.4e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.4326, 'grad_norm': 7.437331199645996, 'learning_rate': 2.3875e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3533, 'grad_norm': 7.70622444152832, 'learning_rate': 2.375e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.2926, 'grad_norm': 7.60914945602417, 'learning_rate': 2.3624999999999998e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.5812, 'grad_norm': 8.040843963623047, 'learning_rate': 2.35e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.2502, 'grad_norm': 7.3959574699401855, 'learning_rate': 2.3375000000000002e-05, 
'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3521, 'grad_norm': 8.238727569580078, 'learning_rate': 2.3250000000000003e-05, 'epoch': 0.07}\\r\\n\",\n      \"{'loss': 3.3969, 'grad_norm': 7.359251022338867, 'learning_rate': 2.3125000000000003e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.5178, 'grad_norm': 8.128018379211426, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.393, 'grad_norm': 7.082696914672852, 'learning_rate': 2.2875e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4172, 'grad_norm': 7.790773868560791, 'learning_rate': 2.275e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3604, 'grad_norm': 7.583011150360107, 'learning_rate': 2.2625e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4316, 'grad_norm': 7.347414970397949, 'learning_rate': 2.25e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4496, 'grad_norm': 6.759352207183838, 'learning_rate': 2.2375000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4145, 'grad_norm': 7.640699863433838, 'learning_rate': 2.2250000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4189, 'grad_norm': 8.391305923461914, 'learning_rate': 2.2125000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3705, 'grad_norm': 8.04839038848877, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.2355, 'grad_norm': 8.35435962677002, 'learning_rate': 2.1875e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3584, 'grad_norm': 7.815989017486572, 'learning_rate': 2.175e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4268, 'grad_norm': 8.53368854522705, 'learning_rate': 2.1625e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.467, 'grad_norm': 7.677575588226318, 'learning_rate': 2.15e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.2885, 'grad_norm': 8.361733436584473, 'learning_rate': 2.1375e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3535, 'grad_norm': 8.110257148742676, 'learning_rate': 2.125e-05, 'epoch': 0.08}\\r\\n\",\n      
\"{'loss': 3.3191, 'grad_norm': 8.498170852661133, 'learning_rate': 2.1125000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3271, 'grad_norm': 8.709260940551758, 'learning_rate': 2.1e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3629, 'grad_norm': 9.01534366607666, 'learning_rate': 2.0875e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3635, 'grad_norm': 7.54719352722168, 'learning_rate': 2.075e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.2623, 'grad_norm': 8.59843635559082, 'learning_rate': 2.0625e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3803, 'grad_norm': 8.170056343078613, 'learning_rate': 2.05e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3506, 'grad_norm': 7.873594284057617, 'learning_rate': 2.0375e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4871, 'grad_norm': 8.418689727783203, 'learning_rate': 2.025e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.2262, 'grad_norm': 8.624137878417969, 'learning_rate': 2.0125e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4514, 'grad_norm': 7.584123611450195, 'learning_rate': 2e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.4514, 'grad_norm': 7.975276470184326, 'learning_rate': 1.9875000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.2789, 'grad_norm': 7.9726481437683105, 'learning_rate': 1.9750000000000002e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3652, 'grad_norm': 7.4362945556640625, 'learning_rate': 1.9625000000000003e-05, 'epoch': 0.08}\\r\\n\",\n      \"{'loss': 3.3795, 'grad_norm': 8.107170104980469, 'learning_rate': 1.9500000000000003e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.2727, 'grad_norm': 7.757025241851807, 'learning_rate': 1.9375e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3055, 'grad_norm': 7.5721869468688965, 'learning_rate': 1.925e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.2545, 'grad_norm': 8.496746063232422, 'learning_rate': 1.9125e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4332, 'grad_norm': 7.52405309677124, 
'learning_rate': 1.9e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4711, 'grad_norm': 7.90508508682251, 'learning_rate': 1.8875e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.39, 'grad_norm': 9.309752464294434, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.09}\\r\\n\",\n      \" 62%|████████████████████████▍              | 2500/4000 [25:37<13:33,  1.84it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:03<00:03,  1.72s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:06<00:02,  2.25s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 31.633207999999996, 'eval_rouge-2': 6.800014, 'eval_rouge-l': 25.123896000000006, 'eval_bleu-4': 0.03327400496195634, 'eval_runtime': 25.5968, 'eval_samples_per_second': 1.953, 'eval_steps_per_second': 0.156, 'epoch': 0.09}\\r\\n\",\n      \" 62%|████████████████████████▍              | 2500/4000 [26:03<13:33,  1.84it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:22<00:00,  7.31s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.2988, 'grad_norm': 8.42829704284668, 'learning_rate': 1.8625000000000002e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3408, 'grad_norm': 9.460935592651367, 'learning_rate': 1.85e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.2467, 'grad_norm': 7.881652355194092, 'learning_rate': 1.8375e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3906, 'grad_norm': 8.49362564086914, 'learning_rate': 1.825e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3859, 'grad_norm': 7.6069016456604, 'learning_rate': 1.8125e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3982, 
'grad_norm': 8.237305641174316, 'learning_rate': 1.8e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.465, 'grad_norm': 7.80671501159668, 'learning_rate': 1.7875e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4805, 'grad_norm': 8.655023574829102, 'learning_rate': 1.775e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3734, 'grad_norm': 8.358222961425781, 'learning_rate': 1.7625e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4732, 'grad_norm': 8.640260696411133, 'learning_rate': 1.75e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3471, 'grad_norm': 8.130788803100586, 'learning_rate': 1.7375e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4129, 'grad_norm': 7.604771614074707, 'learning_rate': 1.725e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.5184, 'grad_norm': 7.612947463989258, 'learning_rate': 1.7125000000000003e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4441, 'grad_norm': 8.518109321594238, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3992, 'grad_norm': 7.822119235992432, 'learning_rate': 1.6875000000000004e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3439, 'grad_norm': 7.961773872375488, 'learning_rate': 1.675e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4062, 'grad_norm': 8.931722640991211, 'learning_rate': 1.6625e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.2609, 'grad_norm': 7.5368194580078125, 'learning_rate': 1.65e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4715, 'grad_norm': 8.477120399475098, 'learning_rate': 1.6375e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4461, 'grad_norm': 9.24991512298584, 'learning_rate': 1.6250000000000002e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.4182, 'grad_norm': 8.294699668884277, 'learning_rate': 1.6125000000000002e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.2432, 'grad_norm': 7.574826717376709, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.09}\\r\\n\",\n      \"{'loss': 3.3834, 'grad_norm': 8.255449295043945, 'learning_rate': 
1.5875e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.385, 'grad_norm': 8.229700088500977, 'learning_rate': 1.575e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.449, 'grad_norm': 8.934239387512207, 'learning_rate': 1.5625e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3947, 'grad_norm': 8.390064239501953, 'learning_rate': 1.55e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3486, 'grad_norm': 8.181641578674316, 'learning_rate': 1.5375e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2568, 'grad_norm': 8.498324394226074, 'learning_rate': 1.525e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2709, 'grad_norm': 7.9656147956848145, 'learning_rate': 1.5125e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2258, 'grad_norm': 7.652721405029297, 'learning_rate': 1.5e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.4379, 'grad_norm': 8.255173683166504, 'learning_rate': 1.4875e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3639, 'grad_norm': 7.929840564727783, 'learning_rate': 1.475e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3836, 'grad_norm': 8.210647583007812, 'learning_rate': 1.4625e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.4367, 'grad_norm': 8.759031295776367, 'learning_rate': 1.45e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.4047, 'grad_norm': 8.681133270263672, 'learning_rate': 1.4374999999999999e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.327, 'grad_norm': 8.468674659729004, 'learning_rate': 1.4249999999999999e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3654, 'grad_norm': 8.48736572265625, 'learning_rate': 1.4125e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.5008, 'grad_norm': 9.581798553466797, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2943, 'grad_norm': 8.112646102905273, 'learning_rate': 1.3875000000000002e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3182, 'grad_norm': 8.913463592529297, 'learning_rate': 1.3750000000000002e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2932, 'grad_norm': 
7.881869792938232, 'learning_rate': 1.3625e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2365, 'grad_norm': 7.5258941650390625, 'learning_rate': 1.3500000000000001e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.3527, 'grad_norm': 9.253165245056152, 'learning_rate': 1.3375000000000002e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.248, 'grad_norm': 8.01251220703125, 'learning_rate': 1.3250000000000002e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.36, 'grad_norm': 8.332780838012695, 'learning_rate': 1.3125e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.2068, 'grad_norm': 9.181897163391113, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.4514, 'grad_norm': 8.965094566345215, 'learning_rate': 1.2875000000000001e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.424, 'grad_norm': 8.944855690002441, 'learning_rate': 1.2750000000000002e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.4562, 'grad_norm': 8.20882511138916, 'learning_rate': 1.2625e-05, 'epoch': 0.1}\\r\\n\",\n      \"{'loss': 3.358, 'grad_norm': 7.769922733306885, 'learning_rate': 1.25e-05, 'epoch': 0.1}\\r\\n\",\n      \" 75%|█████████████████████████████▎         | 3000/4000 [30:40<08:42,  1.91it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:02<00:02,  1.43s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:05<00:01,  1.94s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 33.007998, 'eval_rouge-2': 7.157356, 'eval_rouge-l': 25.306306000000003, 'eval_bleu-4': 0.0348571644891679, 'eval_runtime': 38.0831, 'eval_samples_per_second': 1.313, 
'eval_steps_per_second': 0.105, 'epoch': 0.1}\\r\\n\",\n      \" 75%|█████████████████████████████▎         | 3000/4000 [31:18<08:42,  1.91it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:21<00:00,  7.25s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.4711, 'grad_norm': 8.417685508728027, 'learning_rate': 1.2375000000000001e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3418, 'grad_norm': 8.048948287963867, 'learning_rate': 1.225e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3564, 'grad_norm': 8.270435333251953, 'learning_rate': 1.2125e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2293, 'grad_norm': 7.761234760284424, 'learning_rate': 1.2e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3873, 'grad_norm': 8.1546049118042, 'learning_rate': 1.1875e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.5338, 'grad_norm': 7.905092239379883, 'learning_rate': 1.175e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2963, 'grad_norm': 8.120687484741211, 'learning_rate': 1.1625000000000001e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.292, 'grad_norm': 9.561246871948242, 'learning_rate': 1.1500000000000002e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2029, 'grad_norm': 9.09880542755127, 'learning_rate': 1.1375e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3873, 'grad_norm': 7.879208087921143, 'learning_rate': 1.125e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3383, 'grad_norm': 8.732316970825195, 'learning_rate': 1.1125000000000001e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3205, 'grad_norm': 8.577627182006836, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3717, 'grad_norm': 9.737064361572266, 'learning_rate': 1.0875e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2996, 'grad_norm': 8.619685173034668, 'learning_rate': 1.075e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4496, 'grad_norm': 8.600975036621094, 'learning_rate': 1.0625e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 
3.4277, 'grad_norm': 8.75851821899414, 'learning_rate': 1.05e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4809, 'grad_norm': 7.5685930252075195, 'learning_rate': 1.0375e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.226, 'grad_norm': 8.321500778198242, 'learning_rate': 1.025e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3586, 'grad_norm': 7.587204933166504, 'learning_rate': 1.0125e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4166, 'grad_norm': 8.86058235168457, 'learning_rate': 1e-05, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.382, 'grad_norm': 9.254091262817383, 'learning_rate': 9.875000000000001e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3961, 'grad_norm': 7.718448162078857, 'learning_rate': 9.750000000000002e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4699, 'grad_norm': 8.792988777160645, 'learning_rate': 9.625e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2145, 'grad_norm': 8.899701118469238, 'learning_rate': 9.5e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4141, 'grad_norm': 8.802495956420898, 'learning_rate': 9.375000000000001e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.3627, 'grad_norm': 9.895890235900879, 'learning_rate': 9.25e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4182, 'grad_norm': 8.153362274169922, 'learning_rate': 9.125e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2916, 'grad_norm': 8.173482894897461, 'learning_rate': 9e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.2963, 'grad_norm': 9.929978370666504, 'learning_rate': 8.875e-06, 'epoch': 0.11}\\r\\n\",\n      \"{'loss': 3.4039, 'grad_norm': 7.541258335113525, 'learning_rate': 8.75e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3602, 'grad_norm': 7.881056785583496, 'learning_rate': 8.625e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2324, 'grad_norm': 8.763860702514648, 'learning_rate': 8.500000000000002e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.4018, 'grad_norm': 9.141348838806152, 'learning_rate': 8.375e-06, 'epoch': 0.12}\\r\\n\",\n   
   \"{'loss': 3.3771, 'grad_norm': 8.166316032409668, 'learning_rate': 8.25e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2783, 'grad_norm': 9.261619567871094, 'learning_rate': 8.125000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.4312, 'grad_norm': 8.153901100158691, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.327, 'grad_norm': 7.708031177520752, 'learning_rate': 7.875e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3779, 'grad_norm': 7.920627117156982, 'learning_rate': 7.75e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2857, 'grad_norm': 9.732666015625, 'learning_rate': 7.625e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3588, 'grad_norm': 8.037003517150879, 'learning_rate': 7.5e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2002, 'grad_norm': 8.716700553894043, 'learning_rate': 7.375e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2863, 'grad_norm': 9.12403678894043, 'learning_rate': 7.25e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3447, 'grad_norm': 8.44495677947998, 'learning_rate': 7.1249999999999995e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3088, 'grad_norm': 8.425846099853516, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3281, 'grad_norm': 8.53967571258545, 'learning_rate': 6.875000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3451, 'grad_norm': 9.039155960083008, 'learning_rate': 6.750000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2674, 'grad_norm': 9.248905181884766, 'learning_rate': 6.625000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2703, 'grad_norm': 10.257024765014648, 'learning_rate': 6.5000000000000004e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.4084, 'grad_norm': 8.447395324707031, 'learning_rate': 6.375000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.4488, 'grad_norm': 8.430671691894531, 'learning_rate': 6.25e-06, 'epoch': 0.12}\\r\\n\",\n      \" 
88%|██████████████████████████████████▏    | 3500/4000 [35:52<04:30,  1.85it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:04<00:04,  2.18s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:06<00:02,  2.23s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 32.222722, 'eval_rouge-2': 6.6331180000000005, 'eval_rouge-l': 25.087382, 'eval_bleu-4': 0.03253227960558209, 'eval_runtime': 25.0679, 'eval_samples_per_second': 1.995, 'eval_steps_per_second': 0.16, 'epoch': 0.12}\\r\\n\",\n      \" 88%|██████████████████████████████████▏    | 3500/4000 [36:17<04:30,  1.85it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:08<00:00,  2.14s/it]\\u001B[A\\r\\n\",\n      \"{'loss': 3.3912, 'grad_norm': 9.152791976928711, 'learning_rate': 6.125e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3229, 'grad_norm': 9.17188549041748, 'learning_rate': 6e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2846, 'grad_norm': 8.172340393066406, 'learning_rate': 5.875e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.308, 'grad_norm': 8.928167343139648, 'learning_rate': 5.750000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3578, 'grad_norm': 8.738048553466797, 'learning_rate': 5.625e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.2117, 'grad_norm': 8.161530494689941, 'learning_rate': 5.500000000000001e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.3182, 'grad_norm': 7.672643184661865, 'learning_rate': 5.375e-06, 'epoch': 0.12}\\r\\n\",\n      \"{'loss': 3.4324, 'grad_norm': 9.408201217651367, 'learning_rate': 5.25e-06, 'epoch': 
0.12}\\r\\n\",\n      \"{'loss': 3.2418, 'grad_norm': 9.635400772094727, 'learning_rate': 5.125e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.1869, 'grad_norm': 8.71308708190918, 'learning_rate': 5e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2719, 'grad_norm': 10.24747085571289, 'learning_rate': 4.875000000000001e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.5238, 'grad_norm': 8.207618713378906, 'learning_rate': 4.75e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3982, 'grad_norm': 9.101743698120117, 'learning_rate': 4.625e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2068, 'grad_norm': 9.008282661437988, 'learning_rate': 4.5e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3084, 'grad_norm': 9.63040828704834, 'learning_rate': 4.375e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.1973, 'grad_norm': 8.8562593460083, 'learning_rate': 4.250000000000001e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.298, 'grad_norm': 8.217488288879395, 'learning_rate': 4.125e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3773, 'grad_norm': 8.624151229858398, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3955, 'grad_norm': 8.07646369934082, 'learning_rate': 3.875e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.4082, 'grad_norm': 9.692364692687988, 'learning_rate': 3.75e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3699, 'grad_norm': 9.671299934387207, 'learning_rate': 3.625e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.39, 'grad_norm': 9.423399925231934, 'learning_rate': 3.5000000000000004e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.168, 'grad_norm': 10.555978775024414, 'learning_rate': 3.3750000000000003e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.4062, 'grad_norm': 9.081645011901855, 'learning_rate': 3.2500000000000002e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2328, 'grad_norm': 8.238192558288574, 'learning_rate': 3.125e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2117, 'grad_norm': 
8.344420433044434, 'learning_rate': 3e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2488, 'grad_norm': 9.779040336608887, 'learning_rate': 2.8750000000000004e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2828, 'grad_norm': 8.346026420593262, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.4674, 'grad_norm': 8.168132781982422, 'learning_rate': 2.625e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2598, 'grad_norm': 7.97592830657959, 'learning_rate': 2.5e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3447, 'grad_norm': 10.082160949707031, 'learning_rate': 2.375e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2311, 'grad_norm': 8.935636520385742, 'learning_rate': 2.25e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3922, 'grad_norm': 8.796125411987305, 'learning_rate': 2.1250000000000004e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.315, 'grad_norm': 8.807939529418945, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.2951, 'grad_norm': 8.721334457397461, 'learning_rate': 1.875e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.3289, 'grad_norm': 9.166098594665527, 'learning_rate': 1.7500000000000002e-06, 'epoch': 0.13}\\r\\n\",\n      \"{'loss': 3.46, 'grad_norm': 8.010759353637695, 'learning_rate': 1.6250000000000001e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.4809, 'grad_norm': 8.220529556274414, 'learning_rate': 1.5e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.4166, 'grad_norm': 8.10384750366211, 'learning_rate': 1.3750000000000002e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.458, 'grad_norm': 8.7192964553833, 'learning_rate': 1.25e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.2795, 'grad_norm': 8.834420204162598, 'learning_rate': 1.125e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.3441, 'grad_norm': 9.3894681930542, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.3844, 'grad_norm': 7.872992038726807, 'learning_rate': 
8.750000000000001e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.5111, 'grad_norm': 8.390124320983887, 'learning_rate': 7.5e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.3422, 'grad_norm': 9.196588516235352, 'learning_rate': 6.25e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.2922, 'grad_norm': 8.946027755737305, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.4168, 'grad_norm': 7.884989261627197, 'learning_rate': 3.75e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.4125, 'grad_norm': 9.072811126708984, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.4373, 'grad_norm': 8.543241500854492, 'learning_rate': 1.2500000000000002e-07, 'epoch': 0.14}\\r\\n\",\n      \"{'loss': 3.3844, 'grad_norm': 9.427127838134766, 'learning_rate': 0.0, 'epoch': 0.14}\\r\\n\",\n      \"100%|███████████████████████████████████████| 4000/4000 [40:55<00:00,  1.92it/s]***** Running Evaluation *****\\r\\n\",\n      \"  Num examples = 50\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"\\r\\n\",\n      \"  0%|                                                     | 0/4 [00:00<?, ?it/s]\\u001B[A\\r\\n\",\n      \" 50%|██████████████████████▌                      | 2/4 [00:03<00:03,  1.96s/it]\\u001B[A\\r\\n\",\n      \" 75%|█████████████████████████████████▊           | 3/4 [00:06<00:02,  2.33s/it]\\u001B[A\\r\\n\",\n      \"                                                                                \\u001B[A\\r\\n\",\n      \"\\u001B[A{'eval_rouge-1': 31.607680000000002, 'eval_rouge-2': 6.832874, 'eval_rouge-l': 25.068815999999998, 'eval_bleu-4': 0.03411200822704291, 'eval_runtime': 12.6342, 'eval_samples_per_second': 3.958, 'eval_steps_per_second': 0.317, 'epoch': 0.14}\\r\\n\",\n      \"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00,  1.92it/s]\\r\\n\",\n      \"100%|█████████████████████████████████████████████| 4/4 [00:09<00:00,  2.33s/it]\\u001B[A\\r\\n\",\n      \"                 
                                                               \\u001B[ASaving model checkpoint to ./output/checkpoint-4000\\r\\n\",\n      \"/media/zr/Data/Code/ChatGLM3/venv/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /media/zr/Data/Models/LLM/chatglm3-6b - will assume that the vocabulary was not modified.\\r\\n\",\n      \"  warnings.warn(\\r\\n\",\n      \"\\r\\n\",\n      \"\\r\\n\",\n      \"Training completed. Do not forget to share your model on huggingface.co/models =)\\r\\n\",\n      \"\\r\\n\",\n      \"\\r\\n\",\n      \"{'train_runtime': 2468.7229, 'train_samples_per_second': 6.481, 'train_steps_per_second': 1.62, 'train_loss': 3.419384765625, 'epoch': 0.14}\\r\\n\",\n      \"100%|███████████████████████████████████████| 4000/4000 [41:08<00:00,  1.62it/s]\\r\\n\",\n      \"***** Running Prediction *****\\r\\n\",\n      \"  Num examples = 1070\\r\\n\",\n      \"  Batch size = 16\\r\\n\",\n      \"100%|███████████████████████████████████████████| 67/67 [12:42<00:00, 11.38s/it]\\r\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\\\"1\\\" NCCL_IB_DISABLE=\\\"1\\\" python finetune_hf.py  data/AdvertiseGen_fix  /media/zr/Data/Models/LLM/chatglm3-6b  configs/lora.yaml\"\n   ],\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"17c87410a24d844f\",\n    \"outputId\": \"e347fc7d-875e-40c9-c682-3e064100476b\",\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-14T06:23:41.282431Z\",\n     \"start_time\": \"2024-04-14T05:29:23.810692Z\"\n    }\n   },\n   \"id\": \"17c87410a24d844f\"\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"## 3. 
使用微调的模型进行推理\n\",\n    \"在完成微调任务之后，我们可以查看到 `output` 文件夹下多了很多个`checkpoint-*`的文件夹，这些文件夹对应训练过程中按步数保存的检查点。\n\",\n    \"我们选择最后一个检查点的微调权重，并使用 inference_hf.py 脚本加载后进行推理。\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"id\": \"d9418f6c5c264601\"\n   },\n   \"id\": \"d9418f6c5c264601\"\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00,  2.45it/s]\\r\\n\",\n      \"Setting eos_token is not supported, use the default one.\\r\\n\",\n      \"Setting pad_token is not supported, use the default one.\\r\\n\",\n      \"Setting unk_token is not supported, use the default one.\\r\\n\",\n      \"这款连衣裙采用压褶的版型设计，不规则的木耳边拼接，修饰了腰线，使得身材更加修长，不规则的压褶设计，增加了层次感，不规则的压褶，修饰了腰线，拉长腿部比例，显瘦又性感，套头的设计，方便穿脱，不规则的压褶，增加层次感，视觉上拉长腿部比例，百褶的网纱拼接，增加了层次感，整体气质优雅。\\r\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE=\\\"1\\\" NCCL_IB_DISABLE=\\\"1\\\" python inference_hf.py output/checkpoint-4000/ --prompt \\\"类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则\\\"\"\n   ],\n   \"metadata\": {\n    \"id\": \"5060015c24e97ae\",\n    \"outputId\": \"d3f03d0d-46bf-4c74-9b00-dc0160da0e15\",\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-04-14T06:23:52.725227Z\",\n     \"start_time\": \"2024-04-14T06:23:41.284552Z\"\n    }\n   },\n   \"id\": \"5060015c24e97ae\"\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"source\": [\n    \"## 4. 
总结\n\",\n    \"到此为止，我们就完成了使用单张 GPU 对 ChatGLM3-6B 模型进行 Lora 微调，使其能生成更好的广告。\n\",\n    \"在本章节中，你学会了：\n\",\n    \"+ 如何使用模型进行 Lora 微调\n\",\n    \"+ 微调数据集的准备和对齐\n\",\n    \"+ 使用微调的模型进行推理\"\n   ],\n   \"metadata\": {\n    \"collapsed\": false,\n    \"id\": \"18cd83087f096094\"\n   },\n   \"id\": \"18cd83087f096094\"\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"name\": \"python3\",\n   \"language\": \"python\",\n   \"display_name\": \"Python 3 (ipykernel)\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10\"\n  },\n  \"colab\": {\n   \"provenance\": [],\n   \"machine_shape\": \"hm\",\n   \"gpuType\": \"V100\"\n  },\n  \"accelerator\": \"GPU\"\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
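The finetune cell above streams `transformers.Trainer` progress dicts (`{'loss': ..., 'grad_norm': ..., 'learning_rate': ..., 'epoch': ...}`) interleaved with tqdm bars. If you capture that output to a file, a stdlib-only sketch like the following (a hypothetical helper, not part of `finetune_hf.py`) can recover the records for a quick loss-curve sanity check:

```python
import ast

def parse_trainer_logs(text):
    """Extract the dict-shaped progress lines that transformers.Trainer
    prints, skipping tqdm bars and evaluation banners."""
    records = []
    for line in text.splitlines():
        line = line.strip()
        # Trainer's training-step logs are printed as Python dict literals.
        if line.startswith("{'loss'") and line.endswith("}"):
            records.append(ast.literal_eval(line))
    return records

log = """\
{'loss': 3.3818, 'grad_norm': 8.677833557128906, 'learning_rate': 2.4875e-05, 'epoch': 0.07}
{'loss': 3.4928, 'grad_norm': 7.391153812408447, 'learning_rate': 2.475e-05, 'epoch': 0.07}
***** Running Evaluation *****
"""

records = parse_trainer_logs(log)
print(len(records), records[0]["loss"])  # → 2 3.3818
```

The same filter can be pointed at the real captured output to plot `loss` against `epoch` and confirm the slow downward drift visible in the log above.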
  {
    "path": "finetune_demo/requirements.txt",
    "content": "jieba>=0.42.1\nruamel_yaml>=0.18.6\nrouge_chinese>=1.0.3\njupyter>=1.0.0\ndatasets>=2.18.0\npeft>=0.10.0\ndeepspeed==0.16.2\nmpi4py>=3.1.5"
  },
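The requirements file above mixes `>=` floors with an exact `deepspeed` pin. A minimal sketch (stdlib only; the naive version comparison ignores pre-release tags and epochs that pip handles properly) of verifying an environment against such a list:

```python
import re
from importlib import metadata

REQUIREMENTS = """\
jieba>=0.42.1
peft>=0.10.0
deepspeed==0.16.2
"""

def parse_requirement(line):
    """Split 'name>=1.2.3' / 'name==1.2.3' into (name, op, version)."""
    m = re.fullmatch(r"([A-Za-z0-9_\-]+)(>=|==)(.+)", line.strip())
    if not m:
        raise ValueError(f"unsupported requirement: {line!r}")
    return m.group(1), m.group(2), m.group(3)

def check_installed(req_text):
    """Report each requirement as 'ok', 'missing', or 'mismatch'."""
    report = {}
    for line in req_text.strip().splitlines():
        name, op, want = parse_requirement(line)
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Naive numeric-tuple comparison, e.g. "0.42.1" -> (0, 42, 1).
        as_tuple = lambda v: tuple(int(p) for p in v.split(".") if p.isdigit())
        ok = as_tuple(have) >= as_tuple(want) if op == ">=" else have == want
        report[name] = "ok" if ok else f"mismatch (installed {have})"
    return report

print(parse_requirement("deepspeed==0.16.2"))  # → ('deepspeed', '==', '0.16.2')
```

Running `check_installed(REQUIREMENTS)` in the target venv gives a quick per-package status before launching the demos.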
  {
    "path": "langchain_demo/ChatGLM3.py",
"content": "import ast\nimport json\nfrom langchain.llms.base import LLM\nfrom transformers import AutoTokenizer, AutoModel, AutoConfig\nfrom typing import List, Optional\n\n\nclass ChatGLM3(LLM):\n    max_token: int = 8192\n    do_sample: bool = True\n    temperature: float = 0.8\n    top_p: float = 0.8\n    tokenizer: object = None\n    model: object = None\n    history: List = []\n    has_search: bool = False\n\n    def __init__(self):\n        super().__init__()\n\n    @property\n    def _llm_type(self) -> str:\n        return \"ChatGLM3\"\n\n    def load_model(self, model_name_or_path=None):\n        model_config = AutoConfig.from_pretrained(\n            model_name_or_path,\n            trust_remote_code=True\n        )\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            model_name_or_path,\n            trust_remote_code=True\n        )\n        self.model = AutoModel.from_pretrained(\n            model_name_or_path, config=model_config, trust_remote_code=True, device_map=\"auto\").eval()\n\n    def _tool_history(self, prompt: str):\n        ans = []\n\n        tool_prompts = prompt.split(\n            \"You have access to the following tools:\\n\\n\")[1].split(\"\\n\\nUse a json blob\")[0].split(\"\\n\")\n        tools_json = []\n\n        for tool_desc in tool_prompts:\n            name = tool_desc.split(\":\")[0]\n            description = tool_desc.split(\", args:\")[0].split(\":\")[1].strip()\n            parameters_str = tool_desc.split(\"args:\")[1].strip()\n            parameters_dict = ast.literal_eval(parameters_str)\n            params_cleaned = {}\n            for param, details in parameters_dict.items():\n                params_cleaned[param] = {'description': details['description'], 'type': details['type']}\n\n            tools_json.append({\n                \"name\": name,\n                \"description\": description,\n                \"parameters\": params_cleaned\n            })\n\n        ans.append({\n            \"role\": 
\"system\",\n            \"content\": \"Answer the following questions as best as you can. You have access to the following tools:\",\n            \"tools\": tools_json\n        })\n\n        dialog_parts = prompt.split(\"Human: \")\n        for part in dialog_parts[1:]:\n            if \"\\nAI: \" in part:\n                user_input, ai_response = part.split(\"\\nAI: \")\n                ai_response = ai_response.split(\"\\n\")[0]\n            else:\n                user_input = part\n                ai_response = None\n\n            ans.append({\"role\": \"user\", \"content\": user_input.strip()})\n            if ai_response:\n                ans.append({\"role\": \"assistant\", \"content\": ai_response.strip()})\n\n        query = dialog_parts[-1].split(\"\\n\")[0]\n        return ans, query\n\n    def _extract_observation(self, prompt: str):\n        return_json = prompt.split(\"Observation: \")[-1].split(\"\\nThought:\")[0]\n        self.history.append({\n            \"role\": \"observation\",\n            \"content\": return_json\n        })\n        return\n\n    def _extract_tool(self):\n        if len(self.history[-1][\"metadata\"]) > 0:\n            metadata = self.history[-1][\"metadata\"]\n            content = self.history[-1][\"content\"]\n\n            lines = content.split('\\n')\n            for line in lines:\n                if 'tool_call(' in line and ')' in line and self.has_search is False:\n                    # Extract the string inside the parentheses\n                    params_str = line.split('tool_call(')[-1].split(')')[0]\n\n                    # Parse the key=value parameter pairs\n                    params_pairs = [param.split(\"=\") for param in params_str.split(\",\") if \"=\" in param]\n                    params = {pair[0].strip(): pair[1].strip().strip(\"'\\\"\") for pair in params_pairs}\n                    action_json = {\n                        \"action\": metadata,\n                        \"action_input\": params\n                    }\n                    self.has_search = True\n                    print(\"*****Action*****\")\n                    print(action_json)\n                    print(\"*****Answer*****\")\n                    return f\"\"\"\nAction: \n```\n{json.dumps(action_json, ensure_ascii=False)}\n```\"\"\"\n        final_answer_json = {\n            \"action\": \"Final Answer\",\n            \"action_input\": self.history[-1][\"content\"]\n        }\n        self.has_search = False\n        return f\"\"\"\nAction: \n```\n{json.dumps(final_answer_json, ensure_ascii=False)}\n```\"\"\"\n\n    def _call(self, prompt: str, history: Optional[List] = None, stop: Optional[List[str]] = None):\n        # Avoid mutable default arguments; they are shared across calls.\n        if history is None:\n            history = []\n        if not self.has_search:\n            self.history, query = self._tool_history(prompt)\n        else:\n            self._extract_observation(prompt)\n            query = \"\"\n        _, self.history = self.model.chat(\n            self.tokenizer,\n            query,\n            history=self.history,\n            do_sample=self.do_sample,\n            max_length=self.max_token,\n            temperature=self.temperature,\n        )\n        response = self._extract_tool()\n        history.append((prompt, response))\n        return response\n"
  },
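The line-parsing step in `_extract_tool` above (splitting the text inside `tool_call(...)` into key=value pairs) can be sketched as a standalone helper. `parse_tool_call` is a hypothetical name for illustration; like the original, it cannot handle commas or `=` inside values:

```python
def parse_tool_call(line: str) -> dict:
    """Parse a model-emitted line like tool_call(location='北京') into a dict.

    Hypothetical helper mirroring ChatGLM3._extract_tool's splitting logic.
    """
    # Take the text between "tool_call(" and the first ")".
    params_str = line.split("tool_call(")[-1].split(")")[0]
    # Split on commas, keeping only key=value fragments.
    pairs = [p.split("=") for p in params_str.split(",") if "=" in p]
    # Strip whitespace and surrounding quotes from each value.
    return {p[0].strip(): p[1].strip().strip("'\"") for p in pairs}


print(parse_tool_call("tool_call(location='北京', unit='celsius')"))
# → {'location': '北京', 'unit': 'celsius'}
```

Because the split is purely textual, a value containing `,` or `=` would be parsed incorrectly; this matches the behavior of the demo code above.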
  {
    "path": "langchain_demo/main.py",
"content": "\"\"\"\nThis script demonstrates the use of LangChain's StructuredChatAgent and AgentExecutor alongside various tools.\n\nThe script utilizes the ChatGLM3 model, a large language model for understanding and generating human-like text.\nThe model is loaded from a specified path and integrated into the chat agent.\n\nTools:\n- Calculator: Performs arithmetic calculations.\n- Weather: Provides weather-related information based on input queries.\n- DistanceConverter: Converts distances between meters, kilometers, and feet.\n\nThe agent operates in four modes:\n1. Single Parameter without History: Uses Calculator to perform simple arithmetic.\n2. Single Parameter with History: Uses Weather tool to answer queries about temperature, considering the\nconversation history.\n3. Multiple Parameters without History: Uses DistanceConverter to convert distances between specified units.\n4. Single LangChain Tool: Uses the Arxiv tool to search for scientific articles.\n\nNote:\nThe model's tool calls may fail, which can cause errors or prevent execution. Try reducing the temperature\nparameter of the model, or reducing the number of tools, especially for the third function.\nThe success rate of multi-parameter calls is low. The following errors may occur:\n\nRequired fields [type=missing, input_value={'distance': '30', 'unit': 'm', 'to': 'km'}, input_type=dict]\n\nIn this case the model hallucinates and generates parameters that do not meet the requirements.\nThe top_p and temperature parameters of the model should be adjusted to better solve such problems.\n\nSuccess example:\n\n*****Action*****\n\n{\n    'action': 'weather',\n    'action_input': {\n        'location': '厦门'\n        }\n}\n\n*****Answer*****\n\n{\n    'input': '厦门比北京热吗?',\n    'chat_history': [HumanMessage(content='北京温度多少度'), AIMessage(content='北京现在12度')],\n    'output': '根据最新的天气数据，厦门今天的气温为18度，天气晴朗。而北京今天的气温为12度。所以，厦门比北京热。'\n}\n\n****************\n\n\"\"\"\n\nimport os\n\nfrom langchain import hub\nfrom langchain.agents import AgentExecutor, create_structured_chat_agent, load_tools\nfrom langchain_core.messages import AIMessage, HumanMessage\n\nfrom ChatGLM3 import ChatGLM3\nfrom tools.Calculator import Calculator\nfrom tools.Weather import Weather\nfrom tools.DistanceConversion import DistanceConverter\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\n\nif __name__ == \"__main__\":\n    llm = ChatGLM3()\n    llm.load_model(MODEL_PATH)\n    prompt = hub.pull(\"hwchase17/structured-chat-agent\")\n\n    # for single parameter without history\n\n    tools = [Calculator()]\n    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)\n    agent_executor = AgentExecutor(agent=agent, tools=tools)\n    ans = agent_executor.invoke({\"input\": \"34 * 34\"})\n    print(ans)\n\n    # for single parameter with history\n\n    tools = [Weather()]\n    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)\n    agent_executor = AgentExecutor(agent=agent, tools=tools)\n    ans = agent_executor.invoke(\n        {\n            \"input\": \"厦门比北京热吗?\",\n            \"chat_history\": [\n                HumanMessage(content=\"北京温度多少度\"),\n                AIMessage(content=\"北京现在12度\"),\n            ],\n        }\n    )\n    print(ans)\n\n    # for multiple parameters without history\n\n    tools = [DistanceConverter()]\n    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)\n    agent_executor = AgentExecutor(agent=agent, tools=tools)\n    ans = agent_executor.invoke({\"input\": \"how many meters in 30 km?\"})\n\n    print(ans)\n\n    # for using langchain tools\n\n    tools = load_tools([\"arxiv\"], llm=llm)\n    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)\n    agent_executor = AgentExecutor(agent=agent, tools=tools)\n    ans = agent_executor.invoke({\"input\": \"Describe the paper about GLM 130B\"})\n\n    print(ans)\n"
  },
  {
    "path": "langchain_demo/tools/Calculator.py",
"content": "import abc\nimport math\nimport re\nfrom typing import Type\nfrom langchain.tools import BaseTool\nfrom pydantic import BaseModel, Field\n\n\nclass CalculatorInput(BaseModel):\n    calculation: str = Field(description=\"calculation to perform\")\n\n\nclass Calculator(BaseTool, abc.ABC):\n    name = \"Calculator\"\n    description = \"Useful for when you need to calculate math problems\"\n    args_schema: Type[BaseModel] = CalculatorInput\n\n    def __init__(self):\n        super().__init__()\n\n    def parameter_validation(self, para: str):\n        \"\"\"\n        Write your own parameter validation rules here;\n        the code below is one example.\n        :param para:\n        :return:\n        \"\"\"\n        symbols = [\"math\", \"sqrt\", \"log\", \"sin\", \"cos\", \"tan\", \"pi\"]\n        for sym in symbols:\n            para = para.replace(sym, \"\")\n        pattern = re.compile(\"[+*/\\-%\\d()=\\s.]{3,}\")\n        if re.findall(pattern, para):\n            return True\n\n    def _run(self, calculation: str) -> str:\n        calculation = calculation.replace(\"^\", \"**\")\n        if \"sqrt\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"sqrt\", \"math.sqrt\")\n        if \"log\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"log\", \"math.log\")\n        if \"sin\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"sin\", \"math.sin\")\n        if \"cos\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"cos\", \"math.cos\")\n        if \"tan\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"tan\", \"math.tan\")\n        if \"pi\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"pi\", \"math.pi\")\n        if \"pI\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"pI\", \"math.pi\")\n        if \"PI\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"PI\", \"math.pi\")\n        if \"Pi\" in calculation and \"math\" not in calculation:\n            calculation = calculation.replace(\"Pi\", \"math.pi\")\n        # eval runs arbitrary Python; math is imported above so math.* names resolve.\n        return str(eval(calculation))\n"
  },
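`Calculator._run` above evaluates the model-produced expression with `eval`, which will execute arbitrary Python. A sketch of a stricter alternative, assuming plain arithmetic input: parse with the standard `ast` module and allow only numeric constants and arithmetic operators (`safe_eval` is a hypothetical helper, not part of the demo):

```python
import ast
import operator

# Whitelist of allowed AST operator nodes mapped to their implementations.
_OPS = {
    ast.Add: operator.add, ast.Sub: operator.sub,
    ast.Mult: operator.mul, ast.Div: operator.truediv,
    ast.Pow: operator.pow, ast.Mod: operator.mod,
    ast.USub: operator.neg,
}


def safe_eval(expr: str):
    """Evaluate a plain arithmetic expression without running arbitrary code."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.operand))
        raise ValueError("disallowed expression")
    return _eval(ast.parse(expr, mode="eval"))


print(safe_eval("34 * 34"))  # 1156
```

Unlike `eval`, any name lookup, attribute access, or function call raises `ValueError`, so `math.sqrt(...)` would need explicit support; the sketch covers only the arithmetic subset.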
  {
    "path": "langchain_demo/tools/DistanceConversion.py",
"content": "import abc\nfrom typing import Type\nfrom langchain.tools import BaseTool\nfrom pydantic import BaseModel, Field\n\n\nclass DistanceConversionInput(BaseModel):\n    distance: float = Field(description=\"The numerical value of the distance to convert\")\n    unit: str = Field(description=\"The current unit of the distance (m, km, or feet)\")\n    to_unit: str = Field(description=\"The target unit to convert the distance into (m, km, or feet)\")\n\n\nclass DistanceConverter(BaseTool, abc.ABC):\n    name = \"DistanceConverter\"\n    description = \"Converts distance between meters, kilometers, and feet\"\n    args_schema: Type[BaseModel] = DistanceConversionInput\n\n    def __init__(self):\n        super().__init__()\n\n    def _run(self, distance: float, unit: str, to_unit: str) -> str:\n        unit_conversions = {\n            \"m_to_km\": 0.001,\n            \"km_to_m\": 1000,\n            \"feet_to_m\": 0.3048,\n            \"m_to_feet\": 3.28084,\n            \"km_to_feet\": 3280.84,\n            \"feet_to_km\": 0.0003048\n        }\n\n        if unit == to_unit:\n            return f\"{distance} {unit} is equal to {distance} {to_unit}\"\n\n        # Keep the original value for the reply; convert the working copy to meters first.\n        original_distance = distance\n        if unit == \"km\":\n            distance *= unit_conversions[\"km_to_m\"]\n        elif unit == \"feet\":\n            distance *= unit_conversions[\"feet_to_m\"]\n\n        if to_unit == \"km\":\n            converted_distance = distance * unit_conversions[\"m_to_km\"]\n        elif to_unit == \"feet\":\n            converted_distance = distance * unit_conversions[\"m_to_feet\"]\n        else:\n            converted_distance = distance  # already in meters if this branch is reached\n\n        return f\"{original_distance} {unit} is equal to {converted_distance} {to_unit}\"\n"
  },
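`DistanceConverter` above normalizes the input to meters and then scales into the target unit. The same pivot-unit logic as a minimal standalone sketch (`convert_distance` and `TO_METERS` are hypothetical names for illustration):

```python
# Meters per unit; meters act as the pivot for every conversion.
TO_METERS = {"m": 1.0, "km": 1000.0, "feet": 0.3048}


def convert_distance(value: float, unit: str, to_unit: str) -> float:
    """Convert value between m, km, and feet via a meters pivot."""
    meters = value * TO_METERS[unit]
    return meters / TO_METERS[to_unit]


print(convert_distance(30, "km", "m"))  # 30000.0
```

Using one pivot unit needs only N factors instead of the N×(N-1) pairwise entries in the tool's `unit_conversions` table, and adding a new unit is a single dictionary entry.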
  {
    "path": "langchain_demo/tools/Weather.py",
"content": "import os\nimport requests\n\nfrom typing import Any, Dict, Type\nfrom langchain.tools import BaseTool\nfrom pydantic import BaseModel, Field\n\nclass WeatherInput(BaseModel):\n    location: str = Field(description=\"the location to check the weather for\")\n\n\nclass Weather(BaseTool):\n    name = \"weather\"\n    description = \"Use for searching weather at a specific location\"\n    args_schema: Type[BaseModel] = WeatherInput\n\n    def __init__(self):\n        super().__init__()\n\n    def _run(self, location: str) -> Dict[str, Any]:\n        api_key = os.environ[\"SENIVERSE_KEY\"]\n        url = f\"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c\"\n        response = requests.get(url)\n        if response.status_code == 200:\n            data = response.json()\n            weather = {\n                \"temperature\": data[\"results\"][0][\"now\"][\"temperature\"],\n                \"description\": data[\"results\"][0][\"now\"][\"text\"],\n            }\n            return weather\n        else:\n            raise Exception(\n                f\"Failed to retrieve weather: {response.status_code}\")\n"
  },
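The `Weather` tool above mixes the HTTP call with response parsing. A sketch that separates the parsing step so it can be exercised against a canned Seniverse-style payload without network access (`parse_weather` is a hypothetical helper, and the sample payload shape is assumed from the field accesses in `_run`):

```python
from typing import Any, Dict


def parse_weather(data: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the two fields the tool reports from a Seniverse-style payload."""
    now = data["results"][0]["now"]
    return {"temperature": now["temperature"], "description": now["text"]}


# Canned payload mimicking the API response shape used in Weather._run.
sample = {"results": [{"now": {"temperature": "12", "text": "晴"}}]}
print(parse_weather(sample))  # {'temperature': '12', 'description': '晴'}
```

Keeping the parsing pure makes it easy to unit-test the tool without an API key.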
  {
    "path": "openai_api_demo/api_server.py",
"content": "\"\"\"\nThis script implements an API for the ChatGLM3-6B model,\nformatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).\nIt's designed to be run as a web server using FastAPI and uvicorn,\nmaking the ChatGLM3-6B model accessible through the OpenAI client.\n\nKey Components and Features:\n- Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.\n- FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.\n- API Endpoints:\n  - \"/v1/models\": Lists the available models, specifically ChatGLM3-6B.\n  - \"/v1/chat/completions\": Processes chat completion requests with options for streaming and regular responses.\n  - \"/v1/embeddings\": Processes embedding requests for a list of text inputs.\n- Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.\nFor instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output\nthat many tokens after accounting for the history and prompt tokens.\n- Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.\n- Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.\n- Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.\n\nNote:\n    This script doesn't include the setup for special tokens or multi-GPU support by default.\n    Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.\n    The embedding model only runs on a single GPU.\n\n    Running this script requires 14-15 GB of GPU memory: 
2 GB for the embedding model and 12-13 GB for the FP16 ChatGLM3 LLM.\n\n\n\"\"\"\n\nimport os\nimport time\nimport tiktoken\nimport torch\nimport uvicorn\n\nfrom fastapi import FastAPI, HTTPException, Response\nfrom fastapi.middleware.cors import CORSMiddleware\n\nfrom contextlib import asynccontextmanager\nfrom typing import List, Literal, Optional, Union\nfrom loguru import logger\nfrom pydantic import BaseModel, Field\nfrom transformers import AutoTokenizer, AutoModel\nfrom utils import process_response, generate_chatglm3, generate_stream_chatglm3\nfrom sentence_transformers import SentenceTransformer\n\nfrom sse_starlette.sse import EventSourceResponse\n\n# Set up limit request time\nEventSourceResponse.DEFAULT_PING_INTERVAL = 1000\n\n# set LLM path\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\n# set Embedding Model path\nEMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-m3')\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    yield\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.ipc_collect()\n\n\napp = FastAPI(lifespan=lifespan)\n\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],\n    allow_credentials=True,\n    allow_methods=[\"*\"],\n    allow_headers=[\"*\"],\n)\n\n\nclass ModelCard(BaseModel):\n    id: str\n    object: str = \"model\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = \"owner\"\n    root: Optional[str] = None\n    parent: Optional[str] = None\n    permission: Optional[list] = None\n\n\nclass ModelList(BaseModel):\n    object: str = \"list\"\n    data: List[ModelCard] = []\n\n\nclass FunctionCallResponse(BaseModel):\n    name: Optional[str] = None\n    arguments: Optional[str] = None\n\n\nclass ChatMessage(BaseModel):\n    role: Literal[\"user\", \"assistant\", \"system\", \"function\"]\n    content: str = None\n    name: 
Optional[str] = None\n    function_call: Optional[FunctionCallResponse] = None\n\n\nclass DeltaMessage(BaseModel):\n    role: Optional[Literal[\"user\", \"assistant\", \"system\"]] = None\n    content: Optional[str] = None\n    function_call: Optional[FunctionCallResponse] = None\n\n\n## for Embedding\nclass EmbeddingRequest(BaseModel):\n    input: Union[List[str], str]\n    model: str\n\n\nclass CompletionUsage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n\n\nclass EmbeddingResponse(BaseModel):\n    data: list\n    model: str\n    object: str\n    usage: CompletionUsage\n\n\n# for ChatCompletionRequest\n\nclass UsageInfo(BaseModel):\n    prompt_tokens: int = 0\n    total_tokens: int = 0\n    completion_tokens: Optional[int] = 0\n\n\nclass ChatCompletionRequest(BaseModel):\n    model: str\n    messages: List[ChatMessage]\n    temperature: Optional[float] = 0.8\n    top_p: Optional[float] = 0.8\n    max_tokens: Optional[int] = None\n    stream: Optional[bool] = False\n    tools: Optional[Union[dict, List[dict]]] = None\n    repetition_penalty: Optional[float] = 1.1\n\n\nclass ChatCompletionResponseChoice(BaseModel):\n    index: int\n    message: ChatMessage\n    finish_reason: Literal[\"stop\", \"length\", \"function_call\"]\n\n\nclass ChatCompletionResponseStreamChoice(BaseModel):\n    delta: DeltaMessage\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"function_call\"]]\n    index: int\n\n\nclass ChatCompletionResponse(BaseModel):\n    model: str\n    id: str\n    object: Literal[\"chat.completion\", \"chat.completion.chunk\"]\n    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]\n    created: Optional[int] = Field(default_factory=lambda: int(time.time()))\n    usage: Optional[UsageInfo] = None\n\n\n@app.get(\"/health\")\nasync def health() -> Response:\n    \"\"\"Health check.\"\"\"\n    return Response(status_code=200)\n\n\n@app.post(\"/v1/embeddings\", 
response_model=EmbeddingResponse)\nasync def get_embeddings(request: EmbeddingRequest):\n    if isinstance(request.input, str):\n        embeddings = [embedding_model.encode(request.input)]\n    else:\n        embeddings = [embedding_model.encode(text) for text in request.input]\n    embeddings = [embedding.tolist() for embedding in embeddings]\n\n    def num_tokens_from_string(string: str) -> int:\n        \"\"\"\n        Returns the number of tokens in a text string.\n        use cl100k_base tokenizer\n        \"\"\"\n        encoding = tiktoken.get_encoding('cl100k_base')\n        num_tokens = len(encoding.encode(string))\n        return num_tokens\n\n    response = {\n        \"data\": [\n            {\n                \"object\": \"embedding\",\n                \"embedding\": embedding,\n                \"index\": index\n            }\n            for index, embedding in enumerate(embeddings)\n        ],\n        \"model\": request.model,\n        \"object\": \"list\",\n        \"usage\": CompletionUsage(\n            prompt_tokens=sum(len(text.split()) for text in request.input),\n            completion_tokens=0,\n            total_tokens=sum(num_tokens_from_string(text) for text in request.input),\n        )\n    }\n    return response\n\n\n@app.get(\"/v1/models\", response_model=ModelList)\nasync def list_models():\n    model_card = ModelCard(\n        id=\"chatglm3-6b\"\n    )\n    return ModelList(\n        data=[model_card]\n    )\n\n\n@app.post(\"/v1/chat/completions\", response_model=ChatCompletionResponse)\nasync def create_chat_completion(request: ChatCompletionRequest):\n    global model, tokenizer\n\n    if len(request.messages) < 1 or request.messages[-1].role == \"assistant\":\n        raise HTTPException(status_code=400, detail=\"Invalid request\")\n\n    gen_params = dict(\n        messages=request.messages,\n        temperature=request.temperature,\n        top_p=request.top_p,\n        max_tokens=request.max_tokens or 1024,\n        
echo=False,\n        stream=request.stream,\n        repetition_penalty=request.repetition_penalty,\n        tools=request.tools,\n    )\n    logger.debug(f\"==== request ====\\n{gen_params}\")\n\n    if request.stream:\n\n        # Use stream mode to read the first few characters; if it is not a function call, stream the output directly\n        predict_stream_generator = predict_stream(request.model, gen_params)\n        output = next(predict_stream_generator)\n        if not contains_custom_function(output):\n            return EventSourceResponse(predict_stream_generator, media_type=\"text/event-stream\")\n\n        # Otherwise, obtain the full result at once and determine whether tools need to be called.\n        logger.debug(f\"First result output:\\n{output}\")\n\n        function_call = None\n        if output and request.tools:\n            try:\n                function_call = process_response(output, use_tool=True)\n            except Exception:\n                logger.warning(\"Failed to parse tool call\")\n\n        # CallFunction\n        if isinstance(function_call, dict):\n            function_call = FunctionCallResponse(**function_call)\n\n            \"\"\"\n            In this demo, no tools are registered.\n            You can use the tools implemented in our `tools_using_demo` and add your own streaming tool implementation here.\n            Similar to the following method:\n                function_args = json.loads(function_call.arguments)\n                tool_response = dispatch_tool(tool_name: str, tool_params: dict)\n            \"\"\"\n            tool_response = \"\"\n\n            if not gen_params.get(\"messages\"):\n                gen_params[\"messages\"] = []\n\n            gen_params[\"messages\"].append(ChatMessage(\n                role=\"assistant\",\n                content=output,\n            ))\n            gen_params[\"messages\"].append(ChatMessage(\n                role=\"function\",\n                
name=function_call.name,\n                content=tool_response,\n            ))\n\n            # Streaming output of results after function calls\n            generate = predict(request.model, gen_params)\n            return EventSourceResponse(generate, media_type=\"text/event-stream\")\n\n        else:\n            # Fallback path to avoid exceptions from the parsing above.\n            generate = parse_output_text(request.model, output)\n            return EventSourceResponse(generate, media_type=\"text/event-stream\")\n\n    # Here is the handling of stream=False\n    response = generate_chatglm3(model, tokenizer, gen_params)\n\n    # Remove the first newline character\n    if response[\"text\"].startswith(\"\\n\"):\n        response[\"text\"] = response[\"text\"][1:]\n    response[\"text\"] = response[\"text\"].strip()\n\n    usage = UsageInfo()\n    function_call, finish_reason = None, \"stop\"\n    if request.tools:\n        try:\n            function_call = process_response(response[\"text\"], use_tool=True)\n        except Exception:\n            logger.warning(\"Failed to parse tool call, maybe the response is not a tool call or has already been answered.\")\n\n    if isinstance(function_call, dict):\n        finish_reason = \"function_call\"\n        function_call = FunctionCallResponse(**function_call)\n\n    message = ChatMessage(\n        role=\"assistant\",\n        content=response[\"text\"],\n        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,\n    )\n\n    logger.debug(f\"==== message ====\\n{message}\")\n\n    choice_data = ChatCompletionResponseChoice(\n        index=0,\n        message=message,\n        finish_reason=finish_reason,\n    )\n    task_usage = UsageInfo.model_validate(response[\"usage\"])\n    for usage_key, usage_value in task_usage.model_dump().items():\n        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)\n\n    return ChatCompletionResponse(\n        
model=request.model,\n        id=\"\",  # for open_source model, id is empty\n        choices=[choice_data],\n        object=\"chat.completion\",\n        usage=usage\n    )\n\n\nasync def predict(model_id: str, params: dict):\n    global model, tokenizer\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(role=\"assistant\"),\n        finish_reason=None\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    previous_text = \"\"\n    for new_response in generate_stream_chatglm3(model, tokenizer, params):\n        decoded_unicode = new_response[\"text\"]\n        delta_text = decoded_unicode[len(previous_text):]\n        previous_text = decoded_unicode\n\n        finish_reason = new_response[\"finish_reason\"]\n        if len(delta_text) == 0 and finish_reason != \"function_call\":\n            continue\n\n        function_call = None\n        if finish_reason == \"function_call\":\n            try:\n                function_call = process_response(decoded_unicode, use_tool=True)\n            except Exception:\n                logger.warning(\n                    \"Failed to parse tool call, maybe the response is not a tool call or has already been answered.\")\n\n        if isinstance(function_call, dict):\n            function_call = FunctionCallResponse(**function_call)\n\n        delta = DeltaMessage(\n            content=delta_text,\n            role=\"assistant\",\n            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,\n        )\n\n        choice_data = ChatCompletionResponseStreamChoice(\n            index=0,\n            delta=delta,\n            finish_reason=finish_reason\n        )\n        chunk = ChatCompletionResponse(\n            model=model_id,\n            id=\"\",\n            choices=[choice_data],\n            
object=\"chat.completion.chunk\"\n        )\n        yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(),\n        finish_reason=\"stop\"\n    )\n    chunk = ChatCompletionResponse(\n        model=model_id,\n        id=\"\",\n        choices=[choice_data],\n        object=\"chat.completion.chunk\"\n    )\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n    yield '[DONE]'\n\n\ndef predict_stream(model_id, gen_params):\n    \"\"\"\n    Make function calls compatible with stream-mode output.\n\n    The first seven characters are examined.\n    If they do not indicate a function call, the stream output is generated directly.\n    Otherwise, the complete text of the function call is returned.\n\n    :param model_id:\n    :param gen_params:\n    :return:\n    \"\"\"\n    output = \"\"\n    is_function_call = False\n    has_send_first_chunk = False\n    for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):\n        decoded_unicode = new_response[\"text\"]\n        delta_text = decoded_unicode[len(output):]\n        output = decoded_unicode\n\n        # When no function call has been detected and the output length is > 7,\n        # check for a function call using the special function-name prefix\n        if not is_function_call and len(output) > 7:\n\n            # Determine whether a function is called\n            is_function_call = contains_custom_function(output)\n            if is_function_call:\n                continue\n\n            # Non-function call, direct stream output\n            finish_reason = new_response[\"finish_reason\"]\n\n            # Send an empty string first to avoid truncation by subsequent next() operations.\n            if not has_send_first_chunk:\n                message = DeltaMessage(\n                    content=\"\",\n                    role=\"assistant\",\n             
       function_call=None,\n                )\n                choice_data = ChatCompletionResponseStreamChoice(\n                    index=0,\n                    delta=message,\n                    finish_reason=finish_reason\n                )\n                chunk = ChatCompletionResponse(\n                    model=model_id,\n                    id=\"\",\n                    choices=[choice_data],\n                    created=int(time.time()),\n                    object=\"chat.completion.chunk\"\n                )\n                yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n            send_msg = delta_text if has_send_first_chunk else output\n            has_send_first_chunk = True\n            message = DeltaMessage(\n                content=send_msg,\n                role=\"assistant\",\n                function_call=None,\n            )\n            choice_data = ChatCompletionResponseStreamChoice(\n                index=0,\n                delta=message,\n                finish_reason=finish_reason\n            )\n            chunk = ChatCompletionResponse(\n                model=model_id,\n                id=\"\",\n                choices=[choice_data],\n                created=int(time.time()),\n                object=\"chat.completion.chunk\"\n            )\n            yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    if is_function_call:\n        yield output\n    else:\n        yield '[DONE]'\n\n\nasync def parse_output_text(model_id: str, value: str):\n    \"\"\"\n    Directly output the text content of value\n\n    :param model_id:\n    :param value:\n    :return:\n    \"\"\"\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(role=\"assistant\", content=value),\n        finish_reason=None\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield 
\"{}\".format(chunk.model_dump_json(exclude_unset=True))\n\n    choice_data = ChatCompletionResponseStreamChoice(\n        index=0,\n        delta=DeltaMessage(),\n        finish_reason=\"stop\"\n    )\n    chunk = ChatCompletionResponse(model=model_id, id=\"\", choices=[choice_data], object=\"chat.completion.chunk\")\n    yield \"{}\".format(chunk.model_dump_json(exclude_unset=True))\n    yield '[DONE]'\n\n\ndef contains_custom_function(value: str) -> bool:\n    \"\"\"\n    Determine whether the output is a 'function_call', based on a special function-name prefix.\n\n    For example, the functions defined in \"tools_using_demo/tool_register.py\" are all \"get_xxx\" and start with \"get_\"\n\n    [Note] This is not a rigorous judgment method, only for reference.\n\n    :param value:\n    :return:\n    \"\"\"\n    return bool(value) and 'get_' in value\n\n\nif __name__ == \"__main__\":\n    # Load LLM\n    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\n    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n\n    # load Embedding\n    embedding_model = SentenceTransformer(EMBEDDING_PATH, device=\"cuda\")\n    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)"
  },
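The `predict` generator in `api_server.py` above derives each streamed delta by slicing off the previously sent prefix from the cumulative text returned by the generation loop. That bookkeeping, isolated as a minimal sketch (`to_deltas` is a hypothetical helper, not part of the server):

```python
from typing import Iterable, Iterator


def to_deltas(cumulative_texts: Iterable[str]) -> Iterator[str]:
    """Turn a stream of cumulative decodings into incremental deltas."""
    previous = ""
    for text in cumulative_texts:
        # The new delta is whatever extends the previously sent prefix.
        delta = text[len(previous):]
        previous = text
        if delta:  # skip empty updates, as the server loop does
            yield delta


print(list(to_deltas(["你", "你好", "你好！"])))  # ['你', '好', '！']
```

This is why the server tracks `previous_text` across iterations: `generate_stream_chatglm3` yields the full decoded text each step, while SSE clients expect only the new characters per chunk.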
  {
    "path": "openai_api_demo/docker-compose.yml",
    "content": "version: \"3.6\"\n\nservices:\n        glm3_api:\n                image: python:3.10.13-slim\n                restart: unless-stopped\n                working_dir: /glm3\n                container_name: glm3_api\n                env_file: ./.env\n                networks:\n                        - v_glm3\n                deploy:\n                        resources:\n                                reservations:\n                                        devices:\n                                                - driver: nvidia\n                                                  count: 1\n                                                  capabilities: [gpu]\n                environment:\n                        - MODEL_PATH=/models/chatglm3-6b\n                        - EMBEDDING_PATH=/models/bge-large-zh-v1.5\n                        - TZ=Asia/Shanghai\n                        - PYTHONDONTWRITEBYTECODE=1\n                        - PYTHONUNBUFFERED=1\n                        - DOCKER=True\n                ports:\n                        - 8100:8000\n                volumes:\n                        - ./:/glm3\n                        - ${LOCAL_MODEL_PATH}:/models/chatglm3-6b\n                        - ${LOCAL_EMBEDDING_MODEL_PATH}:/models/bge-large-zh-v1.5\n                command:\n                        - sh\n                        - -c\n                        - |\n                          sed -i s/deb.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list\n                          sed -i s/security.debian.org/mirrors.tencentyun.com/g /etc/apt/sources.list\n                          apt-get update\n                          python -m pip install -i  https://mirror.sjtu.edu.cn/pypi/web/simple --upgrade pip\n                          pip install -r requirements.txt -i  https://pypi.tuna.tsinghua.edu.cn/simple\n                          python api_server.py\nnetworks:\n        v_glm3:\n                driver: bridge"
  },
  {
    "path": "openai_api_demo/langchain_openai_api.py",
"content": "\"\"\"\nThis script is designed for interacting with a local GLM3 AI model using the `ChatGLM3` class\nfrom the `langchain_community` library. It facilitates continuous dialogue with the GLM3 model.\n\n1. Start the Local Model Service: Before running this script, you need to execute the `api_server.py` script\nto start the GLM3 model's service.\n2. Run the Script: The script includes functionality for initializing the LLMChain object and obtaining AI responses,\nallowing the user to input questions and receive AI answers.\n3. Note that this demo does not support streaming.\n\n\"\"\"\nfrom langchain.schema.messages import HumanMessage, SystemMessage, AIMessage\nfrom langchain_community.llms.chatglm3 import ChatGLM3\n\n\ndef get_ai_response(messages, user_input):\n    endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n    llm = ChatGLM3(\n        endpoint_url=endpoint_url,\n        max_tokens=4096,\n        prefix_messages=messages,\n        top_p=0.9\n    )\n    ai_response = llm.invoke(user_input)\n    return ai_response\n\n\ndef continuous_conversation():\n    messages = [\n        SystemMessage(content=\"You are an intelligent AI assistant, named ChatGLM3.\"),\n    ]\n    while True:\n        user_input = input(\"Human (or 'exit' to quit): \")\n        if user_input.lower() == 'exit':\n            break\n        ai_response = get_ai_response(messages, user_input)\n        print(\"ChatGLM3: \", ai_response)\n        messages += [\n            HumanMessage(content=user_input),\n            AIMessage(content=ai_response),\n        ]\n\n\nif __name__ == \"__main__\":\n    continuous_conversation()\n"
  },
  {
    "path": "openai_api_demo/openai_api_request.py",
    "content": "\"\"\"\nThis script is an example of using the OpenAI API to create various interactions with a ChatGLM3 model.\nIt includes functions to:\n\n1. Conduct a basic chat session, asking about weather conditions in multiple cities.\n2. Initiate a simple chat in Chinese, asking the model to tell a short story.\n3. Retrieve and print embeddings for a given text input.\n\nEach function demonstrates a different aspect of the API's capabilities, showcasing how to make requests\nand handle responses.\n\"\"\"\n\nfrom openai import OpenAI\n\nbase_url = \"http://127.0.0.1:8000/v1/\"\nclient = OpenAI(api_key=\"EMPTY\", base_url=base_url)\n\n\ndef function_chat():\n    messages = [{\"role\": \"user\", \"content\": \"What's the weather like in San Francisco, Tokyo, and Paris?\"}]\n    tools = [\n        {\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"get_current_weather\",\n                \"description\": \"Get the current weather in a given location\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"location\": {\n                            \"type\": \"string\",\n                            \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                        },\n                        \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                    },\n                    \"required\": [\"location\"],\n                },\n            },\n        }\n    ]\n\n    response = client.chat.completions.create(\n        model=\"chatglm3-6b\",\n        messages=messages,\n        tools=tools,\n        tool_choice=\"auto\",\n    )\n    if response:\n        content = response.choices[0].message.content\n        print(content)\n    else:\n        print(\"Error:\", response.status_code)\n\n\ndef simple_chat(use_stream=True):\n    messages = [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's \"\n                       \"instructions carefully. Respond using markdown.\",\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"你好，请你用生动的话语给我讲一个小故事吧\"\n        }\n    ]\n    response = client.chat.completions.create(\n        model=\"chatglm3-6b\",\n        messages=messages,\n        stream=use_stream,\n        max_tokens=256,\n        temperature=0.8,\n        presence_penalty=1.1,\n        top_p=0.8)\n    if response:\n        if use_stream:\n            for chunk in response:\n                print(chunk.choices[0].delta.content)\n        else:\n            content = response.choices[0].message.content\n            print(content)\n    else:\n        print(\"Error:\", response.status_code)\n\n\ndef embedding():\n    response = client.embeddings.create(\n        model=\"bge-large-zh-1.5\",\n        input=[\"你好，给我讲一个故事，大概100字\"],\n    )\n    embeddings = response.data[0].embedding\n    print(\"嵌入完成，维度：\", len(embeddings))\n\n\nif __name__ == \"__main__\":\n    simple_chat(use_stream=False)\n    simple_chat(use_stream=True)\n    embedding()\n    function_chat()\n"
  },
  {
    "path": "openai_api_demo/utils.py",
    "content": "import gc\nimport json\nimport torch\nfrom transformers import PreTrainedModel, PreTrainedTokenizer\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom typing import Union, Tuple\n\n\nclass InvalidScoreLogitsProcessor(LogitsProcessor):\n    def __call__(\n            self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        if torch.isnan(scores).any() or torch.isinf(scores).any():\n            scores.zero_()\n            scores[..., 5] = 5e4\n        return scores\n\n\ndef process_response(output: str, use_tool: bool = False) -> Union[str, dict]:\n    content = \"\"\n    for response in output.split(\"<|assistant|>\"):\n        metadata, content = response.split(\"\\n\", maxsplit=1)\n        if not metadata.strip():\n            content = content.strip()\n            content = content.replace(\"[[训练时间]]\", \"2023年\")\n        else:\n            if use_tool:\n                content = \"\\n\".join(content.split(\"\\n\")[1:-1])\n                def tool_call(**kwargs):\n                    return kwargs\n\n                parameters = eval(content)\n                content = {\n                    \"name\": metadata.strip(),\n                    \"arguments\": json.dumps(parameters, ensure_ascii=False)\n                }\n            else:\n                content = {\n                    \"name\": metadata.strip(),\n                    \"content\": content\n                }\n    return content\n\n\n@torch.inference_mode()\ndef generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):\n    messages = params[\"messages\"]\n    tools = params[\"tools\"]\n    temperature = float(params.get(\"temperature\", 1.0))\n    repetition_penalty = float(params.get(\"repetition_penalty\", 1.0))\n    top_p = float(params.get(\"top_p\", 1.0))\n    max_new_tokens = int(params.get(\"max_tokens\", 256))\n    echo = params.get(\"echo\", True)\n    messages = 
process_chatglm_messages(messages, tools=tools)\n    query, role = messages[-1][\"content\"], messages[-1][\"role\"]\n\n    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)\n    inputs = inputs.to(model.device)\n    input_echo_len = len(inputs[\"input_ids\"][0])\n\n    if input_echo_len >= model.config.seq_length:\n        print(f\"Input length larger than {model.config.seq_length}\")\n\n    eos_token_id = [\n        tokenizer.eos_token_id,\n        tokenizer.get_command(\"<|user|>\"),\n        tokenizer.get_command(\"<|observation|>\")\n    ]\n\n    gen_kwargs = {\n        \"max_new_tokens\": max_new_tokens,\n        \"do_sample\": True if temperature > 1e-5 else False,\n        \"top_p\": top_p,\n        \"repetition_penalty\": repetition_penalty,\n        \"logits_processor\": [InvalidScoreLogitsProcessor()],\n    }\n    if temperature > 1e-5:\n        gen_kwargs[\"temperature\"] = temperature\n\n    total_len = 0\n    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):\n        total_ids = total_ids.tolist()[0]\n        total_len = len(total_ids)\n        if echo:\n            output_ids = total_ids[:-1]\n        else:\n            output_ids = total_ids[input_echo_len:-1]\n\n        response = tokenizer.decode(output_ids)\n        if response and response[-1] != \"�\":\n            response, stop_found = apply_stopping_strings(response, [\"<|observation|>\"])\n\n            yield {\n                \"text\": response,\n                \"usage\": {\n                    \"prompt_tokens\": input_echo_len,\n                    \"completion_tokens\": total_len - input_echo_len,\n                    \"total_tokens\": total_len,\n                },\n                \"finish_reason\": \"function_call\" if stop_found else None,\n            }\n\n            if stop_found:\n                break\n\n    # Only last stream result contains finish_reason, we set finish_reason as stop\n    ret = {\n        
\"text\": response,\n        \"usage\": {\n            \"prompt_tokens\": input_echo_len,\n            \"completion_tokens\": total_len - input_echo_len,\n            \"total_tokens\": total_len,\n        },\n        \"finish_reason\": \"stop\",\n    }\n    yield ret\n\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\ndef process_chatglm_messages(messages, tools=None):\n    _messages = messages\n    messages = []\n    msg_has_sys = False\n    if tools:\n        messages.append(\n            {\n                \"role\": \"system\",\n                \"content\": \"Answer the following questions as best as you can. You have access to the following tools:\",\n                \"tools\": tools\n            }\n        )\n        msg_has_sys = True\n\n    for m in _messages:\n        role, content, func_call = m.role, m.content, m.function_call\n        if role == \"function\":\n            messages.append(\n                {\n                    \"role\": \"observation\",\n                    \"content\": content\n                }\n            )\n\n        elif role == \"assistant\" and func_call is not None:\n            for response in content.split(\"<|assistant|>\"):\n                metadata, sub_content = response.split(\"\\n\", maxsplit=1)\n                messages.append(\n                    {\n                        \"role\": role,\n                        \"metadata\": metadata,\n                        \"content\": sub_content.strip()\n                    }\n                )\n        else:\n            if role == \"system\" and msg_has_sys:\n                msg_has_sys = False\n                continue\n            messages.append({\"role\": role, \"content\": content})\n    return messages\n\n\ndef generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):\n    for response in generate_stream_chatglm3(model, tokenizer, params):\n        pass\n    return response\n\n\ndef apply_stopping_strings(reply, stop_strings) -> 
Tuple[str, bool]:\n    stop_found = False\n    for string in stop_strings:\n        idx = reply.find(string)\n        if idx != -1:\n            reply = reply[:idx]\n            stop_found = True\n            break\n\n    if not stop_found:\n        # If something like \"\\nYo\" is generated just before \"\\nYou: is completed, trim it\n        for string in stop_strings:\n            for j in range(len(string) - 1, 0, -1):\n                if reply[-j:] == string[:j]:\n                    reply = reply[:-j]\n                    break\n            else:\n                continue\n\n            break\n\n    return reply, stop_found\n"
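The tail of `apply_stopping_strings` is the subtle part: even when no full stop string is present, it trims a prefix of a stop string dangling at the end of the reply, so a half-generated `<|observation|>` never leaks to the client. The helper is pure string manipulation and can be exercised in isolation; this standalone copy of the logic shows both cases:

```python
from typing import List, Tuple


def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
    """Standalone copy of the helper above, for illustration."""
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            # A complete stop string: cut the reply there.
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # Trim a partially generated stop string at the end of the reply.
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue
            break

    return reply, stop_found


print(apply_stopping_strings("Hello<|observation|>tail", ["<|observation|>"]))  # ('Hello', True)
print(apply_stopping_strings("Hello<|obs", ["<|observation|>"]))                # ('Hello', False)
```

The second call shows why the streaming loop can yield after every decode step: a chunk ending in `<|obs` is trimmed to `Hello` now and re-emitted whole once the next decode either completes or abandons the stop string.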
  },
  {
    "path": "requirements.txt",
    "content": "# basic requirements\n\ntransformers==4.40.0\ncpm_kernels>=1.0.11\ntorch>=2.3.0\nvllm>=0.4.2\ngradio>=4.26.0\nsentencepiece>=0.2.0\nsentence_transformers>=2.7.0\naccelerate>=0.29.2\nstreamlit>=1.33.0\nfastapi>=0.110.0\nloguru~=0.7.2\nmdtex2html>=1.3.0\nlatex2mathml>=3.77.0\njupyter_client>=8.6.1\n\n# for openai demo\nopenai>=1.30.1\npydantic>=2.7.1\nsse-starlette>=2.1.0\nuvicorn>=0.29.0\ntimm>=0.9.16\ntiktoken>=0.6.0\n\n# for langchain demo\n\nlangchain>=0.2.1\nlangchain_community>=0.2.0\nlangchainhub>=0.1.15\narxiv>=2.1.0"
  },
  {
    "path": "resources/WECHAT.md",
    "content": "<div align=\"center\">\n<img src=wechat.jpg width=\"60%\"/>\n\n<p> 扫码关注公众号，加入「ChatGLM交流群」 </p>\n<p> Scan the QR code to follow the official account and join the \"ChatGLM Discussion Group\" </p>\n</div>\n\n"
  },
  {
    "path": "tensorrt_llm_demo/README.md",
    "content": "# 使用NVIDIA TensorRT-LLM部署ChatGLM3\n\n[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main)是NVIDIA开发的高性能推理框架，您可以按照以下步骤来使用TensorRT-LLM部署ChatGLM3模型。\n\n## 1. 安装TensorRT-LLM\n#### 获取TensorRT-LLM代码：\n\n```bash\n# TensorRT-LLM 代码需要使用 git-lfs 拉取\napt-get update && apt-get -y install git git-lfs\n\ngit clone https://github.com/NVIDIA/TensorRT-LLM.git\ncd TensorRT-LLM\n\n# 本流程将使用 v0.7.0 Release 版本\ngit checkout tags/v0.7.0 -b release/0.7.0\ngit submodule update --init --recursive\ngit lfs install\ngit lfs pull\n```\n\n#### 构建docker镜像并安装TensorRT-LLM：\n```bash\nmake -C docker release_build\n```\n\n#### 运行docker镜像：\n```bash\nmake -C docker release_run\n```\n\n## 3. 为ChatGLM3模型构建TensorRT-LLM推理引擎：\n\n#### 安装Python依赖：\n```bash\ncd ./examples/chatglm\npip install -r requirements.txt\napt-get update\napt-get install git-lfs\n```\n#### 从Huggingface下载ChatGLM3模型：\n```\n# 您可以选择具体想部署的模型下载\ngit clone https://huggingface.co/THUDM/chatglm3-6b      chatglm3_6b\ngit clone https://huggingface.co/THUDM/chatglm3-6b-base chatglm3_6b_base\ngit clone https://huggingface.co/THUDM/chatglm3-6b-32k  chatglm3_6b_32k\n```\n\n#### 使用build.py来构建推理引擎：\n以下是一些使用build.py构建推理引擎的示例：\n```bash\n# 构建一个默认的精度为fp16的引擎\npython3 build.py -m chatglm3_6b --output_dir trt_engines/chatglm3_6b/fp16/1-gpu\n\n# 构建一个默认的精度为fp16的引擎，并打开FMHA功能（详见下文）\npython3 build.py -m chatglm3_6b --enable_context_fmha --output_dir trt_engines/chatglm3_6b/fp16/1-gpu\n\n# 构建一个w8a16的引擎\npython3 build.py -m chatglm3_6b --use_weight_only --output_dir trt_engines/chatglm3_6b/weight_only/1-gpu\n\n# 构建一个默认的精度为fp16的引擎，并支持使用两个GPU\npython3 build.py -m chatglm3_6b --world_size 2 --output_dir trt_engines/chatglm3_6b/fp16/2-gpu\n\n# 使用chatglm3_6b_base模型\npython3 build.py -m chatglm3_6b_base --output_dir trt_engines/chatglm3_6b_base/fp16/1-gpu\n\n# 使用chatglm3_6b-32k模型\npython3 build.py -m chatglm3_6b_32k --output_dir trt_engines/chatglm3_6b-32k/fp16/1-gpu\n```\n\n#### 可配置的plugin参数\n\n* 使用 `--use_gpt_attention_plugin <DataType>` 来配置 
GPT Attention plugin (默认使用float16)。\n* 使用 `--use_gemm_plugin <DataType>` 来配置 GEMM plugin (默认使用float16)。\n* 使用 `--use_rmsnorm_plugin <DataType>` 来配置 RMS normolization plugin (默认使用float16)。\n\n#### Fused Multi-Head Attention (FMHA)\n\n* 使用 `--enable_context_fmha` 或 `--enable_context_fmha_fp32_acc` 参数来开启FMHA kernels, 可以获得更好的性能的同时降低显存开销。\n\n* `--use_gpt_attention_plugin` 如果被设置为关闭的话将无法使用FMHA功能。\n\n* `--enable_context_fmha` 将会使用FP16 accumulator, 可能会略微降低精度. 您也可以选择使用`--enable_context_fmha_fp32_acc` 来保护精度，但这会略微降低FMHA的性能提升。\n\n#### Weight-Only 量化\n\n* 使用 `--use_weight_only` 来开启 Weight-Only 量化, 这样可以加速推理并减少显存开销。\n\n* 你还可以通过切换 `--weight_only_precision int8` 或者 `--weight_only_precision int4` 来选择具体是使用int8还是int4量化，默认为Int8。\n\n#### In-flight Batching（须使用NVIDIA Triton进行推理）\n\n* 使用 `--use_inflight_batching` 来开启 In-flight Batching，启用后，Paged KV Cache也会自动启用。\n\n* Paged KV cache中block的数量可以用`--tokens_per_block` 来配置。\n\n更多详细的功能和配置请参考：[TensorRT-LLM ChatGLM实现](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/chatglm)。\n\n## 3. 
使用TensorRT-LLM Python Runtime进行推理\n\n#### 单机单卡的推理示例：\n\n```bash\npython3 ../run.py --input_text \"What's new between ChatGLM3-6B and ChatGLM2-6B?\" \\\n                  --max_output_len 50 \\\n                  --tokenizer_dir chatglm3_6b \\\n                  --engine_dir trt_engines/chatglm3_6b/fp16/1-gpu\n```\n\n#### 单机多卡的推理示例：\n\n```bash\nmpirun -n 2 \\\n    python ../run.py --input_text \"What's new between ChatGLM3-6B and ChatGLM2-6B?\" \\\n                     --max_output_len 50 \\\n                     --tokenizer_dir chatglm3_6b \\\n                     --engine_dir trt_engines/chatglm3_6b/fp16/1-gpu\n```\n\n* 如果您以root权限运行 `mpirun`，则可能需要添加 `--allow-run-as-root` 参数。\n\n#### 运行summarize.py进行文章总结任务：\n\n```bash\npython3 ../summarize.py --test_trt_llm \\\n                        --hf_model_dir chatglm3_6b \\\n                        --engine_dir trt_engines/chatglm3_6b/fp16/1-gpu\n```\n\n#### 运行我们提供的对话demo:[tensorrt_llm_cli_demo.py](tensorrt_llm_cli_demo.py):\n\n```bash\npython3 tensorrt_llm_cli_demo.py --tokenizer_dir chatglm3_6b --engine_dir trt_engines/chatglm3_6b/fp16/1-gpu\n```\n\n运行结果展示：\n```\n用户: what is your name?\nChatGLM3-6B:Hello, I am an assistant named ChatGLM3-6B, and you can call me assistant. What can I help you with??\n\n用户: what is new in ChatGLM3-6B compared with ChatGLM2-6B?\nChatGLM3-6B:ChatGLM3-6B is an improved version of ChatGLM2-6B. Compared with ChatGLM2-6B, ChatGLM3-6B has the following improvements:\n\n1. Enhanced language understanding capabilities: ChatGLM3-6B's language model is based on the GLM3-6B model, which has been pre-trained on more diverse and large-scale data, resulting in better language understanding and generation capabilities.\n\n2. Improved generation ability: ChatGLM3-6B has improved the generation ability compared to ChatGLM2-6B. With more training data and optimization algorithms, ChatGLM3-6B can generate more coherent and natural-looking text.\n\n3. 
Enhanced adaptability to different dialogue scenarios: ChatGLM3-6B has been trained on more diverse dialogue data, including dialogue scenarios with different languages, cultures, and styles, making it more adaptable to different dialogue scenarios.\n\n4. New features and functions: ChatGLM3-6B also has some new features and functions, such as support for multiple choice questions, sentiment analysis, and entity recognition.\n\nIn short, ChatGLM3-6B is more advanced and capable than ChatGLM2-6B, and can better meet the needs of users in various scenarios..\n```\n\n#### 性能测试：\n\n您可以在[这里](https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks/python)查阅到如何测试 TensorRT-LLM 上运行 ChatGLM3 的性能。\n\n## 4. 使用NVIDIA Triton部署在线推理服务器\n使用 NVIDIA Triton 可以部署高性能，高拓展性，高稳定性的推理服务，并且可以开启In-flight Batching功能提升实际推理服务时的吞吐。详见[In-flight Batching Triton Backend](https://github.com/triton-inference-server/tensorrtllm_backend/tree/main/inflight_batcher_llm)。"
  },
  {
    "path": "tensorrt_llm_demo/tensorrt_llm_cli_demo.py",
    "content": "\"\"\"\nThis script is a part of a larger project for generating text using large language models.\nIt includes functionalities for finding engine files, parsing arguments, setting up configurations for different models,\nand executing the generation process with various settings.\nThis script particularly supports models like ChatGLM3-6B and its variants,\nhandling quantization, serialization, and runtime aspects.\n\n\nCopyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\nModifications made by Yuxuan.Zhang @ ZhipuAI on 2023-12-24.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\nModifications:\n\n1. Removed input_file, tokenizer_type, and other parameters unrelated to dialogue. Set num_beams to 1.\n2. 
Adapted single turn dialogue into ChatGLM3-6B template and implemented multi-turn conversations.\n\n\"\"\"\n\nimport argparse\nimport json\nimport torch\nimport transformers\n\nfrom pathlib import Path\nfrom typing import List\n\nimport tensorrt_llm\nfrom tensorrt_llm.quantization import QuantMode\nfrom tensorrt_llm.runtime import (GenerationSession, ModelConfig, SamplingConfig)\n\n\ndef find_engines(dir: Path, model_name: str = \"*\", dtype: str = \"*\", tp_size: str = \"*\", rank: str = \"*\") -> List[Path]:\n    \"\"\"\n    Searches for engine files matching a specified pattern within a directory.\n    This is typically used to locate compiled model files for efficient execution on specific hardware.\n    Parameters:\n        - dir: The directory to search.\n        - model_name, dtype, tp_size, rank:\n        Pattern matching parameters to filter engine files by model name, data type,\n        tensor parallel size, and rank respectively.\n    Returns:\n        - A list of Paths pointing to the engine files.\n    \"\"\"\n\n    template = f\"{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine\"\n    return list(dir.glob(template))\n\n\ndef parse_arguments(args=None):\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--model_name',\n                        type=str,\n                        choices=[\n                            \"chatglm3_6b\",\n                            \"chatglm3_6b_base\",\n                            \"chatglm3_6b_32k\"\n                        ],\n                        default=\"chatglm3_6b\",\n                        help='the name of the model')\n    parser.add_argument('--max_output_len', type=int, default=4096)\n    parser.add_argument('--engine_dir', type=str, default=None)\n    parser.add_argument('--tokenizer_dir', type=str, default=None)\n    parser.add_argument('--temperature', type=float, default=0.95)\n    parser.add_argument('--top_k', type=int, default=1)\n    parser.add_argument('--top_p', type=float, 
default=0.8)\n    parser.add_argument('--random_seed', type=int, default=2023)\n    parser.add_argument('--streaming', default=True, action='store_true')\n    args = parser.parse_args(args)\n\n    return args\n\n\ndef main():\n    \"\"\"\n    The main execution function of the script. It orchestrates the text generation process\n    by performing several key steps:\n        - Parses command-line arguments to configure model details, output specifications,\n        and other user-defined parameters.\n        - Loads the model configuration from a specified directory and prepares the environment for text generation\n        based on the model and hardware specifics.\n        - Sets up the generation session with the appropriate model, tokenizer, and runtime configurations.\n        - Enters a loop to continuously accept user input, generate text based on the provided prompts, and output\n        the model's responses.\n        - Handles special commands such as 'stop' to end the conversation and 'clear' to reset the chat history.\n        - Manages resources and ensures that the generated text is properly formatted and presented to the user.\n    The function is designed to be the entry point of the script, invoking all necessary components and managing the\n    flow of data and control throughout the execution.\n    \"\"\"\n\n    args = parse_arguments()\n\n    config_path = Path(args.engine_dir) / 'config.json'\n    with open(config_path, 'r') as f:\n        config = json.load(f)\n\n    dtype = config['builder_config']['precision']\n    max_output_len = min(config['builder_config']['max_output_len'], args.max_output_len)\n    use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']\n    remove_input_padding = config['builder_config']['remove_input_padding']\n    tp_size = config['builder_config']['tensor_parallel']\n    pp_size = config['builder_config']['pipeline_parallel']\n    world_size = tp_size * pp_size\n\n    assert world_size == 
tensorrt_llm.mpi_world_size(), f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'\n\n    max_output_len = min(max_output_len, args.max_output_len)\n    runtime_rank = tensorrt_llm.mpi_rank()\n    runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank, tp_size=world_size)\n    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)\n\n    serialize_path = find_engines(\n        dir=Path(args.engine_dir),\n        model_name=args.model_name,\n        dtype=dtype,\n        tp_size=world_size,\n        rank=runtime_rank)[0]\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_dir, trust_remote_code=True)\n    model_config = ModelConfig(vocab_size=config['builder_config']['vocab_size'],\n                               num_layers=config['builder_config']['num_layers'],\n                               num_heads=config['builder_config']['num_heads'] // tp_size,\n                               num_kv_heads=(config['builder_config']['num_kv_heads'] + tp_size - 1) // tp_size,\n                               hidden_size=config['builder_config']['hidden_size'] // tp_size,\n                               gpt_attention_plugin=use_gpt_attention_plugin,\n                               remove_input_padding=config['builder_config']['remove_input_padding'],\n                               model_name=args.model_name,\n                               paged_kv_cache=config['builder_config']['paged_kv_cache'],\n                               quant_mode=QuantMode(config['builder_config']['quant_mode']),\n                               dtype=dtype)\n\n    sampling_config = SamplingConfig(\n        end_id=tokenizer.eos_token_id,\n        pad_id=tokenizer.pad_token_id,\n        num_beams=1,\n        temperature=args.temperature,\n        top_k=args.top_k,\n        top_p=args.top_p\n    )\n    sampling_config.random_seed = args.random_seed\n\n    with open(serialize_path, 'rb') as f:\n        
engine_buffer = f.read()\n        session = GenerationSession\n\n    decoder = session(model_config, engine_buffer, runtime_mapping)\n\n    history = []\n    while True:\n        input_text_with_history = \"\"\n        max_input_len = config['builder_config']['max_input_len']\n        input_text = input(\"用户: \")\n        if input_text.lower() == 'stop':\n            break\n\n        if input_text.lower() == 'clear':\n            history = []\n            print(\"ChatGLM3-6B: 对话历史已清空\")\n            continue\n\n        history.append(input_text)\n\n        for idx, content in enumerate(history):\n            if idx % 2 != 0:\n                input_text_with_history += \"{}\\n\".format(content)\n            else:\n                input_text_with_history += \"<|user|>{}\\n<|assistant|>\".format(content)\n\n        tokenized = tokenizer(\n            input_text_with_history,\n            return_tensors=\"pt\",\n            padding=True,\n            return_length=True\n        )\n\n        input_ids = tokenized['input_ids'].int()\n        input_lengths = tokenized['length'].int()\n        max_input_len_real = torch.max(input_lengths)\n        if max_input_len_real > max_input_len:\n            input_ids = input_ids[:, :max_input_len]\n            input_lengths = torch.where(input_lengths > max_input_len, max_input_len, input_lengths)\n        else:\n            max_input_len = max_input_len_real\n        if remove_input_padding:\n            input_ids_no_padding = (torch.zeros(1, torch.sum(input_lengths), dtype=torch.int32))\n\n            lengths_acc = torch.cumsum(torch.cat([torch.IntTensor([0]), input_lengths]), dim=0)\n\n            for i in range(len(input_ids)):\n                input_ids_no_padding[0, lengths_acc[i]:lengths_acc[i + 1]] = torch.IntTensor(\n                    input_ids[i, max_input_len - input_lengths[i]:max_input_len])\n\n            input_ids = input_ids_no_padding\n\n        elif use_gpt_attention_plugin:\n            input_ids_padding_right 
= torch.zeros_like(input_ids) + sampling_config.end_id\n            for i, sample in enumerate(input_ids):\n                nPadding = 0\n                for token in sample:\n                    if token == sampling_config.pad_id:\n                        nPadding += 1\n                    else:\n                        break\n                input_ids_padding_right[i, :len(sample[nPadding:])] = sample[nPadding:]\n            input_ids = input_ids_padding_right\n        input_lengths = torch.tensor([input_ids.shape[-1]], dtype=torch.int32)\n        decoder.setup(1, max_input_len, max_output_len, 1)\n        output = decoder.decode(\n            input_ids.contiguous().cuda(),\n            input_lengths.contiguous().cuda(),\n            sampling_config,\n            output_sequence_lengths=True,\n            return_dict=True,\n            streaming=args.streaming\n        )\n\n        print(\"ChatGLM3-6B:\", end=\"\")\n        generated_text = \"\"\n        if args.streaming:\n            for output_item in output:\n                output_id = output_item[\"output_ids\"]\n                output_sequence_lengths = output_item[\"sequence_lengths\"]\n                output_id = output_id[0, 0, output_sequence_lengths[0, 0] - 1]\n                output_word = tokenizer.convert_ids_to_tokens(int(output_id))\n                output_word = output_word.replace(\"▁\", \" \")\n                output_word = tokenizer.convert_tokens_to_string(output_word)\n                print(output_word, end=\"\", flush=True)\n                generated_text += output_word\n            print(\"\\n\")\n        else:\n            torch.cuda.synchronize()\n            output_ids = output[\"output_ids\"][0]\n            output = output_ids[0, input_lengths.item():]\n            generated_text = tokenizer.decode(output, skip_special_tokens=True)\n            print(generated_text)\n\n        history.append(generated_text)\n\n    del decoder\n    print(f\"Good bye!\")\n\n\nif __name__ == 
'__main__':\n    main()\n"
  },
  {
    "path": "tools_using_demo/README.md",
    "content": "# 工具调用\n本文档将介绍如何使用 ChatGLM3-6B 进行工具调用。目前只有 ChatGLM3-6B 模型支持工具调用，而 ChatGLM3-6B-Base 和 ChatGLM3-6B-32K 模型不支持。\n\n## 构建 System Prompt\n这里以两个工具调用为例，首先准备好要构建的数据的描述信息。\n\n```python\ntools = [\n    {\n        \"name\": \"track\",\n        \"description\": \"追踪指定股票的实时价格\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"symbol\": {\n                    \"description\": \"需要追踪的股票代码\"\n                }\n            },\n            \"required\": ['symbol']\n        }\n    },\n    {\n        \"name\": \"text-to-speech\",\n        \"description\": \"将文本转换为语音\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"text\": {\n                    \"description\": \"需要转换成语音的文本\"\n                },\n                \"voice\": {\n                    \"description\": \"要使用的语音类型（男声、女声等）\"\n                },\n                \"speed\": {\n                    \"description\": \"语音的速度（快、中等、慢等）\"\n                }\n            },\n            \"required\": ['text']\n        }\n    }\n]\nsystem_info = {\"role\": \"system\", \"content\": \"Answer the following questions as best as you can. 
You have access to the following tools:\", \"tools\": tools}\n```\n请确保工具的定义格式与例子中一致以获得最优的性能\n\n## 提出问题\n注意：目前 ChatGLM3-6B 的工具调用只支持通过 `chat` 方法，不支持  `stream_chat` 方法。\n```python\nhistory = [system_info]\nquery = \"帮我查询股票10111的价格\"\nresponse, history = model.chat(tokenizer, query, history=history)\nprint(response)\n```\n这里期望得到的输出为\n```json\n{\"name\": \"track\", \"parameters\": {\"symbol\": \"10111\"}}\n```\n这表示模型需要调用工具 `track`，并且需要传入参数 `symbol`。\n\n## 调用工具，生成回复\n这里需要自行实现调用工具的逻辑。假设已经得到了返回结果，将结果以 json 格式返回给模型并得到回复。\n```python\nresult = json.dumps({\"price\": 12412}, ensure_ascii=False)\nresponse, history = model.chat(tokenizer, result, history=history, role=\"observation\")\nprint(response)\n```\n这里 `role=\"observation\"` 表示输入的是工具调用的返回值而不是用户输入，不能省略。\n\n期望得到的输出为\n```\n根据您的查询，经过API的调用，股票10111的价格是12412。\n```\n\n这表示本次工具调用已经结束，模型根据返回结果生成回复。对于比较复杂的问题，模型可能需要进行多次工具调用。这时，可以根据返回的 `response` 是 `str` 还是 `dict` 来判断返回的是生成的回复还是工具调用请求。"
  },
  {
    "path": "tools_using_demo/README_en.md",
    "content": "# Tool Invocation\nThis document will introduce how to use the ChatGLM3-6B for tool invocation. Currently, only the ChatGLM3-6B model supports tool invocation, while the ChatGLM3-6B-Base and ChatGLM3-6B-32K models do not support it.\n\n## Building System Prompt\nHere are two examples of tool invocation. First, prepare the description information of the data to be built.\n\n```python\ntools = [\n    {\n        \"name\": \"track\",\n        \"description\": \"Track the real-time price of a specified stock\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"symbol\": {\n                    \"description\": \"The stock code that needs to be tracked\"\n                }\n            },\n            \"required\": ['symbol']\n        }\n    },\n    {\n        \"name\": \"text-to-speech\",\n        \"description\": \"Convert text to speech\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"text\": {\n                    \"description\": \"The text that needs to be converted into speech\"\n                },\n                \"voice\": {\n                    \"description\": \"The type of voice to use (male, female, etc.)\"\n                },\n                \"speed\": {\n                    \"description\": \"The speed of the speech (fast, medium, slow, etc.)\"\n                }\n            },\n            \"required\": ['text']\n        }\n    }\n]\nsystem_info = {\"role\": \"system\", \"content\": \"Answer the following questions as best as you can. 
You have access to the following tools:\", \"tools\": tools}\n```\n\nPlease ensure that the definition format of the tool is consistent with the example to obtain optimal performance.\n\n## Asking Questions\nNote: Currently, the tool invocation of ChatGLM3-6B only supports the `chat` method and does not support the `stream_chat` method.\n```python\nhistory = [system_info]\nquery = \"Help me inquire the price of stock 10111\"\nresponse, history = model.chat(tokenizer, query, history=history)\nprint(response)\n```\nThe expected output here is\n```json\n{\"name\": \"track\", \"parameters\": {\"symbol\": \"10111\"}}\n```\nThis indicates that the model needs to call the tool `track`, and the parameter `symbol` needs to be passed in.\n\n## Invoke Tool, Generate Response\nHere, you need to implement the logic of calling the tool yourself. Assuming that the return result has been obtained, return the result to the model in json format and get a response.\n```python\nresult = json.dumps({\"price\": 12412}, ensure_ascii=False)\nresponse, history = model.chat(tokenizer, result, history=history, role=\"observation\")\nprint(response)\n```\nHere `role=\"observation\"` indicates that the input is the return value of the tool invocation rather than user input, and it cannot be omitted.\n\nThe expected output is\n```\nBased on your query, after the API call, the price of stock 10111 is 12412.\n```\n\nThis indicates that this tool invocation has ended, and the model generates a response based on the return result. For more complex questions, the model may need to make multiple tool invocations. At this time, you can judge whether the returned `response` is `str` or `dict` to determine whether the return is a generated response or a tool invocation request."
  },
  {
    "path": "tools_using_demo/cli_demo_tool.py",
    "content": "\"\"\"\nThis demo script is designed for interacting with ChatGLM3-6B to demonstrate its function-call capabilities.\n\"\"\"\n\nimport os\nimport platform\nfrom transformers import AutoTokenizer, AutoModel\n\nMODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')\nTOKENIZER_PATH = os.environ.get(\"TOKENIZER_PATH\", MODEL_PATH)\n\ntokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)\nmodel = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map=\"auto\").eval()\n\nos_name = platform.system()\nclear_command = 'cls' if os_name == 'Windows' else 'clear'\nstop_stream = False\n\n\ndef build_prompt(history):\n    prompt = \"欢迎使用 ChatGLM3-6B 模型，输入内容即可进行对话，clear 清空对话历史，stop 终止程序\"\n    for query, response in history:\n        prompt += f\"\\n\\n用户：{query}\"\n        prompt += f\"\\n\\nChatGLM3-6B：{response}\"\n    return prompt\n\n\ntools = [\n    {\n        'name': 'track',\n        'description': '追踪指定股票的实时价格',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'symbol': {'description': '需要追踪的股票代码'}\n            },\n            'required': []\n        }\n    },\n    {\n        'name': '/text-to-speech',\n        'description': '将文本转换为语音',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'text': {'description': '需要转换成语音的文本'},\n                'voice': {'description': '要使用的语音类型（男声、女声等）'},\n                'speed': {'description': '语音的速度（快、中等、慢等）'}\n            },\n            'required': []\n        }\n    },\n    {\n        'name': '/image_resizer',\n        'description': '调整图片的大小和尺寸',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'image_file': {'description': '需要调整大小的图片文件'},\n                'width': {'description': '需要调整的宽度值'},\n                'height': {'description': '需要调整的高度值'}\n            },\n            'required': []\n        }\n    },\n    {\n        'name': '/foodimg',\n        'description': '通过给定的食品名称生成该食品的图片',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'food_name': {'description': '需要生成图片的食品名称'}\n            },\n            'required': []\n        }\n    }\n]\nsystem_item = {\n    \"role\": \"system\",\n    \"content\": \"Answer the following questions as best as you can. You have access to the following tools:\",\n    \"tools\": tools\n}\n\n\ndef main():\n    past_key_values, history = None, [system_item]\n    role = \"user\"\n    global stop_stream\n    print(\"欢迎使用 ChatGLM3-6B 模型，输入内容即可进行对话，clear 清空对话历史，stop 终止程序\")\n    while True:\n        query = input(\"\\n用户：\") if role == \"user\" else input(\"\\n结果：\")\n        if query.strip() == \"stop\":\n            break\n        if query.strip() == \"clear\":\n            past_key_values, history = None, [system_item]\n            role = \"user\"\n            os.system(clear_command)\n            print(\"欢迎使用 ChatGLM3-6B 模型，输入内容即可进行对话，clear 清空对话历史，stop 终止程序\")\n            continue\n        print(\"\\nChatGLM：\", end=\"\")\n        response, history = model.chat(tokenizer, query, history=history, role=role)\n        print(response, end=\"\", flush=True)\n        print(\"\")\n        # A dict response is a tool-invocation request; feed the tool result back as an observation.\n        if isinstance(response, dict):\n            role = \"observation\"\n        else:\n            role = \"user\"\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tools_using_demo/openai_api_demo.py",
    "content": "import json\n\nfrom openai import OpenAI\nfrom colorama import init, Fore\nfrom loguru import logger\n\nfrom tool_register import get_tools, dispatch_tool\n\ninit(autoreset=True)\nclient = OpenAI(\n    base_url=\"http://127.0.0.1:8000/v1\",\n    api_key=\"xxx\",\n)\n\ntools = get_tools()\n\n\ndef run_conversation(query: str, stream=False, tools=None, max_retry=5):\n    params = dict(model=\"chatglm3\", messages=[{\"role\": \"user\", \"content\": query}], stream=stream)\n    if tools:\n        params[\"tools\"] = tools\n    response = client.chat.completions.create(**params)\n\n    for _ in range(max_retry):\n        if not stream:\n            if response.choices[0].message.function_call:\n                function_call = response.choices[0].message.function_call\n                logger.info(f\"Function Call Response: {function_call.model_dump()}\")\n\n                function_args = json.loads(function_call.arguments)\n                tool_response = dispatch_tool(function_call.name, function_args)\n                logger.info(f\"Tool Call Response: {tool_response}\")\n\n                params[\"messages\"].append(response.choices[0].message)\n                params[\"messages\"].append(\n                    {\n                        \"role\": \"function\",\n                        \"name\": function_call.name,\n                        \"content\": tool_response,  # result returned by the tool call\n                    }\n                )\n            else:\n                reply = response.choices[0].message.content\n                logger.info(f\"Final Reply: \\n{reply}\")\n                return\n\n        else:\n            output = \"\"\n            for chunk in response:\n                content = chunk.choices[0].delta.content or \"\"\n                print(Fore.BLUE + content, end=\"\", flush=True)\n                output += content\n\n                if chunk.choices[0].finish_reason == \"stop\":\n                    return\n\n                elif chunk.choices[0].finish_reason == \"function_call\":\n                    print(\"\\n\")\n\n                    function_call = chunk.choices[0].delta.function_call\n                    logger.info(f\"Function Call Response: {function_call.model_dump()}\")\n\n                    function_args = json.loads(function_call.arguments)\n                    tool_response = dispatch_tool(function_call.name, function_args)\n                    logger.info(f\"Tool Call Response: {tool_response}\")\n\n                    params[\"messages\"].append(\n                        {\n                            \"role\": \"assistant\",\n                            \"content\": output\n                        }\n                    )\n                    params[\"messages\"].append(\n                        {\n                            \"role\": \"function\",\n                            \"name\": function_call.name,\n                            \"content\": tool_response,\n                        }\n                    )\n\n                    break\n\n        response = client.chat.completions.create(**params)\n\n\nif __name__ == \"__main__\":\n    query = \"你是谁\"\n    run_conversation(query, stream=True)\n\n    logger.info(\"\\n=========== next conversation ===========\")\n\n    query = \"帮我查询北京的天气怎么样\"\n    run_conversation(query, tools=tools, stream=True)\n"
  },
  {
    "path": "tools_using_demo/tool_register.py",
    "content": "\"\"\"\nThis module implements tool registration: tools registered here can be invoked by the model.\n\"\"\"\nimport inspect\nimport traceback\nfrom copy import deepcopy\nfrom pprint import pformat\nfrom types import GenericAlias\nfrom typing import get_origin, Annotated\n\n_TOOL_HOOKS = {}\n_TOOL_DESCRIPTIONS = {}\n\n\ndef register_tool(func: callable):\n    tool_name = func.__name__\n    tool_description = inspect.getdoc(func).strip()\n    python_params = inspect.signature(func).parameters\n    tool_params = []\n    for name, param in python_params.items():\n        annotation = param.annotation\n        if annotation is inspect.Parameter.empty:\n            raise TypeError(f\"Parameter `{name}` missing type annotation\")\n        if get_origin(annotation) != Annotated:\n            raise TypeError(f\"Annotation type for `{name}` must be typing.Annotated\")\n\n        typ, (description, required) = annotation.__origin__, annotation.__metadata__\n        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__\n        if not isinstance(description, str):\n            raise TypeError(f\"Description for `{name}` must be a string\")\n        if not isinstance(required, bool):\n            raise TypeError(f\"Required for `{name}` must be a bool\")\n\n        tool_params.append({\n            \"name\": name,\n            \"description\": description,\n            \"type\": typ,\n            \"required\": required\n        })\n    tool_def = {\n        \"name\": tool_name,\n        \"description\": tool_description,\n        \"params\": tool_params\n    }\n\n    print(\"[registered tool] \" + pformat(tool_def))\n    _TOOL_HOOKS[tool_name] = func\n    _TOOL_DESCRIPTIONS[tool_name] = tool_def\n\n    return func\n\n\ndef dispatch_tool(tool_name: str, tool_params: dict) -> str:\n    if tool_name not in _TOOL_HOOKS:\n        return f\"Tool `{tool_name}` not found. Please use a provided tool.\"\n    tool_call = _TOOL_HOOKS[tool_name]\n    try:\n        ret = tool_call(**tool_params)\n    except Exception:\n        ret = traceback.format_exc()\n    return str(ret)\n\n\ndef get_tools() -> dict:\n    return deepcopy(_TOOL_DESCRIPTIONS)\n\n\n# Tool definitions\n\n@register_tool\ndef random_number_generator(\n        seed: Annotated[int, 'The random seed used by the generator', True],\n        range: Annotated[tuple[int, int], 'The range of the generated numbers', True],\n) -> int:\n    \"\"\"\n    Generates a random number x, s.t. range[0] <= x < range[1]\n    \"\"\"\n    if not isinstance(seed, int):\n        raise TypeError(\"Seed must be an integer\")\n    if not isinstance(range, tuple):\n        raise TypeError(\"Range must be a tuple\")\n    if not isinstance(range[0], int) or not isinstance(range[1], int):\n        raise TypeError(\"Range must be a tuple of integers\")\n\n    import random\n    # randrange gives range[0] <= x < range[1], matching the docstring (randint would include the upper bound)\n    return random.Random(seed).randrange(*range)\n\n\n@register_tool\ndef get_weather(\n        city_name: Annotated[str, 'The name of the city to be queried', True],\n) -> str:\n    \"\"\"\n    Get the current weather for `city_name`\n    \"\"\"\n\n    if not isinstance(city_name, str):\n        raise TypeError(\"City name must be a string\")\n\n    key_selection = {\n        \"current_condition\": [\"temp_C\", \"FeelsLikeC\", \"humidity\", \"weatherDesc\", \"observation_time\"],\n    }\n    import requests\n    try:\n        resp = requests.get(f\"https://wttr.in/{city_name}?format=j1\")\n        resp.raise_for_status()\n        resp = resp.json()\n        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}\n    except Exception:\n        ret = \"Error encountered while fetching weather data!\\n\" + traceback.format_exc()\n\n    return str(ret)\n\n\nif __name__ == \"__main__\":\n    print(dispatch_tool(\"get_weather\", {\"city_name\": \"beijing\"}))\n    print(get_tools())\n"
  },
  {
    "path": "update_requirements.sh",
    "content": "#!/bin/bash\n\npython -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple\n\n# read -r prevents backslashes in requirement specifiers from being mangled\nwhile read -r requirement; do\n    python -m pip install --upgrade \"$requirement\" -i https://pypi.tuna.tsinghua.edu.cn/simple\ndone < requirements.txt\n"
  }
]