Full Code of zai-org/ChatGLM-6B for AI

main 401bf3a8a7dd cached
37 files
352.1 KB
100.3k tokens
106 symbols
1 requests
Download .txt
Showing preview only (366K chars total). Download the full file or copy to clipboard to get everything.
Repository: zai-org/ChatGLM-6B
Branch: main
Commit: 401bf3a8a7dd
Files: 37
Total size: 352.1 KB

Directory structure:
gitextract_3917dvmb/

├── .github/
│   └── ISSUE_TEMPLATE/
│       ├── bug_report.yaml
│       ├── config.yml
│       └── feature_request.yml
├── .gitignore
├── FAQ.md
├── LICENSE
├── MODEL_LICENSE
├── PROJECT.md
├── README.md
├── README_en.md
├── UPDATE.md
├── api.py
├── cli_demo.py
├── cli_demo_vision.py
├── improve/
│   ├── README.md
│   └── data_sample.jsonl
├── ptuning/
│   ├── README.md
│   ├── README_en.md
│   ├── arguments.py
│   ├── deepspeed.json
│   ├── ds_train_finetune.sh
│   ├── evaluate.sh
│   ├── evaluate_finetune.sh
│   ├── main.py
│   ├── train.sh
│   ├── train_chat.sh
│   ├── trainer.py
│   ├── trainer_seq2seq.py
│   ├── web_demo.py
│   └── web_demo.sh
├── requirements.txt
├── resources/
│   └── WECHAT.md
├── utils.py
├── web_demo.py
├── web_demo2.py
├── web_demo_old.py
└── web_demo_vision.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yaml
================================================
name: 🐞 Bug/Help
description: File a bug/issue
title: "[BUG/Help] <title>"
labels: []
body:
- type: checkboxes
  attributes:
    label: Is there an existing issue for this?
    description: Please search to see if an issue already exists for the bug you encountered.
    options:
    - label: I have searched the existing issues
      required: true
- type: textarea
  attributes:
    label: Current Behavior
    description: | 
      A concise description of what you're experiencing, with screenshot attached if possible.
      Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
  validations:
    required: true
- type: textarea
  attributes:
    label: Expected Behavior
    description: A concise description of what you expected to happen.
  validations:
    required: false
- type: textarea
  attributes:
    label: Steps To Reproduce
    description: Steps to reproduce the behavior.
    placeholder: |
      1. In this environment...
      2. With this config...
      3. Run '...'
      4. See error...
  validations:
    required: true
- type: textarea
  attributes:
    label: Environment
    description: |
      examples:
        - **OS**: Ubuntu 20.04
        - **Python**: 3.8
        - **Transformers**: 4.26.1
        - **PyTorch**: 1.12
        - **CUDA Support**: True
    value: |
        - OS:
        - Python:
        - Transformers:
        - PyTorch:
        - CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :
    render: markdown
  validations:
    required: true
- type: textarea
  attributes:
    label: Anything else?
    description: |
      Links? References? Anything that will give us more context about the issue you are encountering!
  validations:
    required: false


================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false

================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: Feature request
description: Suggest an idea for this project
title: "[Feature] <title>"
labels: []
body:
- type: textarea
  attributes:
    label: Is your feature request related to a problem? Please describe.
    description: | 
      A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
  validations:
    required: false
- type: textarea
  attributes:
    label: Solutions
    description: |
      Describe the solution you'd like
      A clear and concise description of what you want to happen.
  validations:
    required: true
- type: textarea
  attributes:
    label: Additional context
    description: Add any other context or screenshots about the feature request here.
  validations:
    required: false


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
history/

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Mac system file
model/
.idea

================================================
FILE: FAQ.md
================================================
## Q1

**Mac直接加载量化后的模型出现提示 `clang: error: unsupported option '-fopenmp'**

这是由于Mac由于本身缺乏omp导致的,此时可运行但是单核。需要单独安装 openmp 依赖,即可在Mac下使用OMP:

```bash
# 参考`https://mac.r-project.org/openmp/`
## 假设: gcc(clang)是14.x版本,其他版本见R-Project提供的表格
curl -O https://mac.r-project.org/openmp/openmp-14.0.6-darwin20-Release.tar.gz
sudo tar fvxz openmp-14.0.6-darwin20-Release.tar.gz -C /
```
此时会安装下面几个文件:`/usr/local/lib/libomp.dylib`, `/usr/local/include/ompt.h`, `/usr/local/include/omp.h`, `/usr/local/include/omp-tools.h`。

> 注意:如果你之前运行`ChatGLM`项目失败过,最好清一下Huggingface的缓存,i.e. 默认下是 `rm -rf ${HOME}/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4`。由于使用了`rm`命令,请明确知道自己在删除什么。

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright Zhengxiao Du

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: MODEL_LICENSE
================================================
The ChatGLM-6B License

1. 定义

“许可方”是指分发其软件的 ChatGLM-6B 模型团队。

“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。(不包括二代模型 ChatGLM2-6B 以及后续模型)

2. 许可授予

根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。

上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。

3.限制

您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。

您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。

4.免责声明

本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。

5. 责任限制

除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。

6.争议解决

本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。

请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。

1. Definitions

“Licensor” means the ChatGLM-6B Model Team that distributes its Software.

“Software” means the ChatGLM-6B model parameters made available under this license (does not include the second-generation model ChatGLM2-6B and subsequent models).

2. License Grant

Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software.

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

3. Restriction

You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.

You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.

4. Disclaimer

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

5. Limitation of Liability

EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

6. Dispute Resolution

This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.

Note that the license is subject to update to a more comprehensive version.  For any questions related to the license and copyright, please contact us at license@zhipuai.cn.


================================================
FILE: PROJECT.md
================================================
# 友情链接

对 ChatGLM 进行加速或者重新实现的开源项目:
* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度
* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
* [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G



基于或使用了 ChatGLM-6B 的开源项目:
* [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): 支持ChatGLM-6B的学术写作与编程工具箱,具有模块化和多线程调用LLM的特点,可并行调用多种LLM。
* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
* [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain):中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成,增加web search功能、知识库选择功能和支持知识增量更新
* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能)
* [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。
* [ChatGLM-6B-Engineering](https://github.com/LemonQu-GIT/ChatGLM-6B-Engineering):基于 ChatGLM-6B 后期调教,网络爬虫及 [Stable Diffusion](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 实现的网络搜索及图片生成
* [ChatGLM-OpenAI-API](https://github.com/ninehills/chatglm-openai-api): 将 ChatGLM-6B 封装为 OpenAI API 风格,并通过 ngrok/cloudflare 对外提供服务,从而将 ChatGLM 快速集成到 OpenAI 的各种生态中。
* [ChatSQL](https://github.com/cubenlp/ChatSQL): 基于ChatGLM+SBERT实现NL2SQL本地化,并直接连接数据库查询数据返回结果,使得生成的SQL语句更具有实用性。

对 ChatGLM-6B 进行微调的开源项目:
* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题
* [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning):实现了ChatGLM-6B模型的监督微调和完整RLHF训练,汇总10余种指令数据集和3种微调方案,实现了4/8比特量化和模型权重融合,提供微调模型快速部署方法。
* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。
* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf)


针对 ChatGLM-6B 的教程/文档:
* [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
* [搭建深度学习docker容器以运行 ChatGLM-6B - Luck_zy](https://www.luckzym.com/tags/ChatGLM-6B/)

如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。



================================================
FILE: README.md
================================================
# ChatGLM-6B

<p align="center">
   🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📄<a href="https://arxiv.org/pdf/2406.12793" target="_blank"> Report </a> <br>
</p>
<p align="center">
    👋 加入我们的  <a href="https://discord.gg/fK2dz4bg" target="_blank">Discord</a> 和 <a href="resources/WECHAT.md" target="_blank">WeChat</a>
</p>
<p align="center">
📍在 <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">智谱AI开放平台</a> 体验和使用更大规模的 GLM 商业模型。
</p>

*Read this in [English](README_en.md).*

## GLM-4 开源模型和API

我们已经发布最新的 **GLM-4** 大语言对话模型,该模型在多个指标上有了新的突破,您可以在以下两个渠道体验我们的最新模型。

+ [GLM-4 开源模型](https://github.com/THUDM/GLM-4) 我们已经开源了 GLM-4-9B 系列模型,在各项指标的ce是上有明显提升,欢迎尝试。
+ [智谱清言](https://chatglm.cn/main/detail?fr=ecology_x) 体验最新版 GLM-4,包括 **GLMs,All tools**等功能。
+ [API平台](https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9) 新一代 API 平台已经上线,您可以直接在
  API
  平台上体验 `GLM-4-0520`、`GLM-4-air`、`GLM-4-airx`、`GLM-4-flash`、`GLM-4`、`GLM-3-Turbo`、`CharacterGLM-3`,`CogView-3`
  等新模型。
  其中`GLM-4`、`GLM-3-Turbo`两个模型支持了 `System Prompt`、`Function Call`、 `Retrieval`、`Web_Search`等新功能,欢迎体验。

+ [GLM-4 API 开源教程](https://github.com/MetaGLM/glm-cookbook/) GLM-4 API教程和基础应用,欢迎尝试。
  API相关问题可以在本开源教程疑问,或者使用 [GLM-4 API AI助手](https://open.bigmodel.cn/shareapp/v1/?share_code=sQwt5qyqYVaNh1O_87p8O)
  来获得常见问题的帮助。

-----
## 介绍

ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。
ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答,更多信息请参考我们的[博客](https://chatglm.cn/blog)。欢迎通过 [chatglm.cn](https://chatglm.cn) 体验更大规模的 ChatGLM 模型。

为了方便下游开发者针对自己的应用场景定制模型,我们同时实现了基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调方法 [(使用指南)](ptuning/README.md) ,INT4 量化级别下最低只需 7GB 显存即可启动微调。

ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。

-----

ChatGLM-6B 开源模型旨在与开源社区一起推动大模型技术发展,恳请开发者和大家遵守[开源协议](MODEL_LICENSE),勿将开源模型和代码及基于开源项目产生的衍生物用于任何可能给国家和社会带来危害的用途以及用于任何未经过安全评估和备案的服务。**目前,本项目团队未基于 ChatGLM-6B 开发任何应用,包括网页端、安卓、苹果 iOS 及 Windows App 等应用。**

尽管模型在训练的各个阶段都尽力确保数据的合规性和准确性,但由于 ChatGLM-6B 模型规模较小,且模型受概率随机性因素影响,无法保证输出内容的准确性,且模型易被误导(详见[局限性](README.md#局限性))。**本项目不承担开源模型和代码导致的数据安全、舆情风险或发生任何模型被误导、滥用、传播、不当利用而产生的风险和责任。**

## 更新信息
**[2023/07/25]** 发布 [CodeGeeX2](https://github.com/THUDM/CodeGeeX2) ,基于 ChatGLM2-6B 的代码生成模型,代码能力全面提升,更多特性包括:

* **更强大的代码能力**:CodeGeeX2-6B 进一步经过了 600B 代码数据预训练,相比 CodeGeeX 一代模型,在代码能力上全面提升,[HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) 评测集的六种编程语言均大幅提升 (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%),在Python上达到 35.9\% 的 Pass@1 一次通过率,超越规模更大的 StarCoder-15B。
* **更优秀的模型特性**:继承 ChatGLM2-6B 模型特性,CodeGeeX2-6B 更好支持中英文输入,支持最大 8192 序列长度,推理速度较一代 大幅提升,量化后仅需6GB显存即可运行,支持轻量级本地化部署。
* **更全面的AI编程助手**:CodeGeeX插件([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex))后端升级,支持超过100种编程语言,新增上下文补全、跨文件补全等实用功能。结合 Ask CodeGeeX 交互式AI编程助手,支持中英文对话解决各种编程问题,包括且不限于代码解释、代码翻译、代码纠错、文档生成等,帮助程序员更高效开发。

**[2023/06/25]** 发布 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B),ChatGLM-6B 的升级版本,在保留了了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B 引入了如下新特性:

1. **更强大的性能**:基于 ChatGLM 初代模型的开发经验,我们全面升级了 ChatGLM2-6B 的基座模型。ChatGLM2-6B 使用了 [GLM](https://github.com/THUDM/GLM) 的混合目标函数,经过了 1.4T 中英标识符的预训练与人类偏好对齐训练,[评测结果](#评测结果)显示,相比于初代模型,ChatGLM2-6B 在 MMLU(+23%)、CEval(+33%)、GSM8K(+571%) 、BBH(+60%)等数据集上的性能取得了大幅度的提升,在同尺寸开源模型中具有较强的竞争力。
2. **更长的上下文**:基于 [FlashAttention](https://github.com/HazyResearch/flash-attention) 技术,我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。但当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限,我们会在后续迭代升级中着重进行优化。
3. **更高效的推理**:基于 [Multi-Query Attention](http://arxiv.org/abs/1911.02150) 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%,INT4 量化下,6G 显存支持的对话长度由 1K 提升到了 8K。

更多信息参见 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B)。

**[2023/06/14]** 发布 [WebGLM](https://github.com/THUDM/WebGLM),一项被接受于KDD 2023的研究工作,支持利用网络信息生成带有准确引用的长回答。

![](resources/webglm.jpg)

**[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。

![](resources/visualglm.png)

可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。

**[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文指令微调数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。

<details><summary><b>以下是更新前后的英文问题对比:</b></summary>

* 问题:Describe a time when you had to make a difficult decision.
  - v1.0:
  ![](resources/english-q1-old.png)
  - v1.1:
  ![](resources/english-q1-new.png)
* 问题:Describe the function of a computer motherboard
  - v1.0:
  ![](resources/english-q2-old.png)
  - v1.1:
  ![](resources/english-q2-new.png)
* 问题:Develop a plan to reduce electricity usage in a home.
  - v1.0:
  ![](resources/english-q3-old.png)
  - v1.1:
  ![](resources/english-q3-new.png)
* 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
  - v1.0:
  ![](resources/english-q4-old.png)
  - v1.1:
  ![](resources/english-q4-new.png)
</details>

更多更新信息参见 [UPDATE.md](UPDATE.md)

## 友情链接
对 ChatGLM 进行加速的开源项目:
* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最高可以实现 9000+ tokens/s 的推理速度
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
* [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
* [InferLLM](https://github.com/MegEngine/InferLLM):轻量级 C++ 推理,可以实现本地 x86,Arm 处理器上实时聊天,手机上也同样可以实时运行,运行内存只需要 4G

基于或使用了 ChatGLM-6B 的开源项目:
* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
* [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
* [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
* [Chuanhu Chat](https://github.com/GaiZhenbiao/ChuanhuChatGPT): 为各个大语言模型和在线模型API提供美观易用、功能丰富、快速部署的用户界面,支持ChatGLM-6B。

支持 ChatGLM-6B 和相关应用在线训练的示例项目:
* [ChatGLM-6B 的部署与微调教程](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
* [ChatGLM-6B 结合 langchain 实现本地知识库 QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)

第三方评测:
* [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986)

更多开源项目参见 [PROJECT.md](PROJECT.md)

## 使用方式

### 硬件需求

| **量化等级**   | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
| -------------- | ------------------------- | --------------------------------- |
| FP16(无量化) | 13 GB                     | 14 GB                             |
| INT8           | 8 GB                     | 9 GB                             |
| INT4           | 6 GB                      | 7 GB                              |
### 环境安装

使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。

此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。在 MacOS 上请参考 [Q1](FAQ.md#q1)。

### 代码调用 

可以通过如下代码调用 ChatGLM-6B 模型来生成对话:

```python
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:

1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。

如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
```
模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以在 `from_pretrained` 的调用中增加 `revision="v1.1.0"` 参数。`v1.1.0` 是当前最新的版本号,完整的版本列表参见 [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log)。

### 从本地加载模型
以上代码会由 `transformers` 自动下载模型实现和参数。完整的模型实现可以在 [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b)。如果你的网络环境较差,下载模型参数可能会花费较长时间甚至失败。此时可以先将模型下载到本地,然后从本地加载。

从 Hugging Face Hub 下载模型需要先[安装Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行
```Shell
git clone https://huggingface.co/THUDM/chatglm-6b
```

如果你从 Hugging Face Hub 上下载 checkpoint 的速度较慢,可以只下载模型实现
```Shell
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
```
然后从[这里](https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/)手动下载模型参数文件,并将下载的文件替换到本地的 `chatglm-6b` 目录下。

将模型下载到本地之后,将以上代码中的 `THUDM/chatglm-6b` 替换为你本地的 `chatglm-6b` 文件夹的路径,即可从本地加载模型。

**Optional** 模型的实现仍然处在变动中。如果希望固定使用的模型实现以保证兼容性,可以执行
```Shell
git checkout v1.1.0
```

## Demo & API

我们提供了一个基于 [Gradio](https://gradio.app) 的网页版 Demo 和一个命令行 Demo。使用时首先需要下载本仓库:

```shell
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
```

### 网页版 Demo

![web-demo](resources/web-demo.gif)

首先安装 Gradio:`pip install gradio`,然后运行仓库中的 [web_demo.py](web_demo.py): 

```shell
python web_demo.py
```

程序会运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。最新版 Demo 实现了打字机效果,速度体验大大提升。注意,由于国内 Gradio 的网络访问较为缓慢,启用 `demo.queue().launch(share=True, inbrowser=True)` 时所有网络会经过 Gradio 服务器转发,导致打字机体验大幅下降,现在默认启动方式已经改为 `share=False`,如有需要公网访问的需求,可以重新修改为 `share=True` 启动。

感谢 [@AdamBear](https://github.com/AdamBear) 实现了基于 Streamlit 的网页版 Demo,运行方式见[#117](https://github.com/THUDM/ChatGLM-6B/pull/117).

### 命令行 Demo

![cli-demo](resources/cli-demo.png)

运行仓库中 [cli_demo.py](cli_demo.py):

```shell
python cli_demo.py
```

程序会在命令行中进行交互式的对话,在命令行中输入指示并回车即可生成回复,输入 `clear` 可以清空对话历史,输入 `stop` 终止程序。

### API部署
首先需要安装额外的依赖 `pip install fastapi uvicorn`,然后运行仓库中的 [api.py](api.py):
```shell
python api.py
```
默认部署在本地的 8000 端口,通过 POST 方法进行调用
```shell
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```
得到的返回值为
```shell
{
  "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
  "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
  "status":200,
  "time":"2023-03-23 21:38:40"
}
```

## 低成本部署
### 模型量化
默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:

```python
# 按需修改,目前只支持 4/8 bit 量化
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
```

进行 2 至 3 轮对话后,8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用。随着对话轮数的增多,对应消耗显存也随之增长,由于采用了相对位置编码,理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048(训练长度)后性能会逐渐下降。

模型量化会带来一定的性能损失,经过测试,ChatGLM-6B 在 4-bit 量化下仍然能够进行自然流畅的生成。使用 [GPT-Q](https://arxiv.org/abs/2210.17323) 等量化方案可以进一步压缩量化精度/提升相同量化精度下的模型性能,欢迎大家提出对应的 Pull Request。

量化过程需要在内存中首先加载 FP16 格式的模型,消耗大概 13GB 的内存。如果你的内存不足的话,可以直接加载量化后的模型,INT4 量化后的模型仅需大概 5.2GB 的内存:
```python
# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
```
量化模型的参数文件也可以从[这里](https://cloud.tsinghua.edu.cn/d/674208019e314311ab5c/)手动下载。

### CPU 部署
如果你没有 GPU 硬件的话,也可以在 CPU 上进行推理,但是推理速度会更慢。使用方法如下(需要大概 32GB 内存)
```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
```

如果你的内存不足,可以直接加载量化后的模型:
```python
# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
```

如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型)

### Mac 部署
对于搭载了 Apple Silicon 或者 AMD GPU 的Mac,可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly(正确的版本号应该是2.1.0.dev2023xxxx,而不是2.0.0)。

目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端:
```python
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
```

加载半精度的 ChatGLM-6B 模型需要大概 13GB 内存。内存较小的机器(比如 16GB 内存的 MacBook Pro),在空余内存不足的情况下会使用硬盘上的虚拟内存,导致推理速度严重变慢。此时可以使用量化后的模型如 chatglm-6b-int4。因为 GPU 上量化的 kernel 是使用 CUDA 编写的,因此无法在 MacOS 上使用,只能使用 CPU 进行推理。
```python
# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
```
为了充分使用 CPU 并行,还需要[单独安装 OpenMP](FAQ.md#q1)。

### 多卡部署
如果你有多张 GPU,但是每张 GPU 的显存大小都不足以容纳完整的模型,那么可以将模型切分在多张GPU上。首先安装 accelerate: `pip install accelerate`,然后通过如下方法加载模型:
```python
from utils import load_model_on_gpus
model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
```
即可将模型部署到两张 GPU 上进行推理。你可以将 `num_gpus` 改为你希望使用的 GPU 数。默认是均匀切分的,你也可以传入 `device_map` 参数来自己指定。 

## 高效参数微调
基于 [P-tuning v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调。具体使用方法详见 [ptuning/README.md](ptuning/README.md)。

## ChatGLM-6B 示例

以下是一些使用 `web_demo.py` 得到的示例截图。更多 ChatGLM-6B 的可能,等待你来探索发现!

<details><summary><b>自我认知</b></summary>

![](examples/self-introduction.png)

</details>

<details><summary><b>提纲写作</b></summary>

![](examples/blog-outline.png)

</details>

<details><summary><b>文案写作</b></summary>

![](examples/ad-writing-2.png)

![](examples/comments-writing.png)

</details>

<details><summary><b>邮件写作助手</b></summary>

![](examples/email-writing-1.png)

![](examples/email-writing-2.png)

</details>

<details><summary><b>信息抽取</b></summary>

![](examples/information-extraction.png)

</details>

<details><summary><b>角色扮演</b></summary>

![](examples/role-play.png)

</details>

<details><summary><b>评论比较</b></summary>

![](examples/sport.png)

</details>

<details><summary><b>旅游向导</b></summary>

![](examples/tour-guide.png)

</details>

## 局限性

由于 ChatGLM-6B 的小规模,其能力仍然有许多局限性。以下是我们目前发现的一些问题:

- 模型容量较小:6B 的小容量,决定了其相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B 可能会生成不正确的信息;它也不擅长逻辑类问题(如数学、编程)的解答。
    <details><summary><b>点击查看例子</b></summary>
    
    ![](limitations/factual_error.png)
    
    ![](limitations/math_error.png)
    
    </details>
  
- 产生有害说明或有偏见的内容:ChatGLM-6B 只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。(内容可能具有冒犯性,此处不展示)

- 英文能力不足:ChatGLM-6B 训练时使用的指示/回答大部分都是中文的,仅有极小一部分英文内容。因此,如果输入英文指示,回复的质量远不如中文,甚至与中文指示下的内容矛盾,并且出现中英夹杂的情况。

- 易被误导,对话能力较弱:ChatGLM-6B 对话能力还比较弱,而且 “自我认知” 存在问题,并很容易被误导并产生错误的言论。例如当前版本的模型在被误导的情况下,会在自我认知上发生偏差。
    <details><summary><b>点击查看例子</b></summary>

    ![](limitations/self-confusion_google.jpg)
    
    ![](limitations/self-confusion_openai.jpg)
    
    ![](limitations/self-confusion_tencent.jpg)
    
    </details>

## 协议

本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。ChatGLM-6B 权重对学术研究**完全开放**,在填写[问卷](https://open.bigmodel.cn/mla/form)进行登记后**亦允许免费商业使用**。

## 引用

如果你觉得我们的工作有帮助的话,请考虑引用下列论文

```
@misc{glm2024chatglm,
      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, 
      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
      year={2024},
      eprint={2406.12793},
      archivePrefix={arXiv},
      primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
}
```


================================================
FILE: README_en.md
================================================
# ChatGLM-6B

<p align="center">
   🌐 <a href="https://chatglm.cn/blog" target="_blank">Blog</a> • 🤗 <a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📄<a href="https://arxiv.org/pdf/2406.12793" target="_blank"> Report </a> <br>
</p>
<p align="center">
    👋 Join our  <a href="https://discord.gg/fK2dz4bg" target="_blank">Discord</a> and <a href="resources/WECHAT.md" target="_blank">WeChat</a>
</p>
<p align="center">
📍Experience and use a larger-scale GLM business model on the <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">Zhipu AI Open Platform</a>
</p>

## GLM-4 Open Source Model and API

We have released the latest **GLM-4** model, which has made new breakthroughs in multiple indicators. You can directly
experience our latest model in the following two channels.

+ [GLM-4 open source model](https://github.com/THUDM/GLM-4) We have open sourced the GLM-4-9B series models, which have
  significantly improved the performance of various indicators. Welcome to try.
+ [Zhipu Qingyan](https://chatglm.cn/main/detail?fr=ecology_x) Experience the latest version of GLM-4, including **GLMs,
  All tools** and other functions.
+ [API platform](https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9) The new generation of API
  platform has been launched. You can directly experience new models such
  as `GLM-4-0520`, `GLM-4-air`, `GLM-4-airx`, `GLM-4-flash`, `GLM-4`, `GLM-3-Turbo`, `CharacterGLM-3`, `CogView-3` on
  the API platform.
  Among them, the two models `GLM-4` and `GLM-3-Turbo` support new functions such
  as `System Prompt`, `Function Call`, `Retrieval`, and `Web_Search`. You are welcome to experience them.

+ [GLM-4 API open source tutorial](https://github.com/MetaGLM/glm-cookbook/) GLM-4 API tutorial and basic applications,
  welcome to try.
  API-related questions can be asked in this open source tutorial, or
  use [GLM-4 API AI Assistant](https://open.bigmodel.cn/shareapp/v1/?share_code=sQwt5qyqYVaNh1O_87p8O)
  to get help with common problems.


## Introduction

ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). Welcome to use the larger ChatGLM model on [chatglm.cn](https://chatglm.cn) 

ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.

In order to facilitate downstream developers to customize the model for their own application scenarios, we also implements an parameter-efficient tuning method based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2)[(Guidelines)](ptuning/README_en.md). Tuning requires at least 7GB of GPU memory at INT4 quantization level.

ChatGLM-6B weights are **completely open** for academic research, and **free commercial use** is also allowed after completing the [questionnaire](https://open.bigmodel.cn/mla/form).

Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces.

## Update
**[2023/07/25]** Release [CodeGeeX2](https://github.com/THUDM/CodeGeeX2), which is based on ChatGLM2-6B and trained on more code data. It has the following features:

* **More Powerful Coding Capabilities**: CodeGeeX2-6B has been further pre-trained on 600B code tokens, which has been comprehensively improved in coding capability compared to the first-generation. On the [HumanEval-X](https://huggingface.co/datasets/THUDM/humaneval-x) benchmark, all six languages have been significantly improved (Python +57%, C++ +71%, Java +54%, JavaScript +83%, Go +56%, Rust +321\%), and in Python it reached 35.9% of Pass@1 one-time pass rate, surpassing the larger StarCoder-15B.
* **More Useful Features**: Inheriting the ChatGLM2-6B model features, CodeGeeX2-6B better supports both Chinese and English prompts, maximum 8192 sequence length, and the inference speed is significantly improved compared to the first-generation. After quantization, it only needs 6GB of GPU memory for inference, thus supports lightweight local deployment.
* **Comprehensive AI Coding Assistant**: The backend of CodeGeeX plugin ([VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex), [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex)) is upgraded, supporting 100+ programming languages, and adding practical functions such as infilling and cross-file completion. Combined with the "Ask CodeGeeX" interactive AI coding assistant, it can be used to solve various programming problems via Chinese or English dialogue, including but not limited to code summarization, code translation, debugging, and comment generation, which helps increasing the efficiency of developpers.

**[2023/06/25]** Release [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B), the second-generation version of ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the following new features:

1. **Stronger Performance**: Based on the development experience of the first-generation ChatGLM model, we have fully upgraded the base model of ChatGLM2-6B. ChatGLM2-6B uses the hybrid objective function of [GLM](https://github.com/THUDM/GLM), and has undergone pre-training with 1.4T bilingual tokens and human preference alignment training. The [evaluation results](README.md#evaluation-results) show that, compared to the first-generation model, ChatGLM2-6B has achieved substantial improvements in performance on datasets like MMLU (+23%), CEval (+33%), GSM8K (+571%), BBH (+60%), showing strong competitiveness among models of the same size.
2. **Longer Context**: Based on [FlashAttention](https://github.com/HazyResearch/flash-attention) technique, we have extended the context length of the base model from 2K in ChatGLM-6B to 32K, and trained with a context length of 8K during the dialogue alignment, allowing for more rounds of dialogue. However, the current version of ChatGLM2-6B has limited understanding of single-round ultra-long documents, which we will focus on optimizing in future iterations.
3. **More Efficient Inference**: Based on [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique, ChatGLM2-6B has more efficient inference speed and lower GPU memory usage: under the official  implementation, the inference speed has increased by 42% compared to the first generation; under INT4 quantization, the dialogue length supported by 6G GPU memory has increased from 1K to 8K.

Fore more information, please refer to [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B).

**[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding.

![](resources/visualglm.png)

You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to  [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B).

**[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers .

<details><summary><b>The following is a comparison of English questions before and after the update</b></summary>

* Question: Describe a time when you had to make a difficult decision.
  - v1.0:
  ![](resources/english-q1-old.png)
  - v1.1:
  ![](resources/english-q1-new.png)
* Question: Describe the function of a computer motherboard
  - v1.0:
  ![](resources/english-q2-old.png)
  - v1.1:
  ![](resources/english-q2-new.png)
* Question: Develop a plan to reduce electricity usage in a home.
  - v1.0:
  ![](resources/english-q3-old.png)
  - v1.1:
  ![](resources/english-q3-new.png)
* Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
  - v1.0:
  ![](resources/english-q4-old.png)
  - v1.1:
  ![](resources/english-q4-new.png)
</details>

For more update info, please refer to [UPDATE.md](UPDATE.md).

## Projects
Open source projects that accelerate ChatGLM:
* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): Inference acceleration for ChatGLM-6B, up to 9000+ tokens/s inference speed.
* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An MNN-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory
* [JittorLLMs](https://github.com/Jittor/JittorLLMs): Running ChatGLM-6B in FP16 with a minimum of 3G GPU memory or no GPU at all, with Linux, windows, and Mac support
* [InferLLM](https://github.com/MegEngine/InferLLM): Lightweight C++ inference, which can realize real-time chat on local x86 and Arm processors, and can also run in real time on mobile phones. It only requires 4G of running memory.

Open source projects using ChatGLM-6B:
* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): ChatGLM application based on langchain, realizing Q&A based on extensible knowledge base
* [Wenda](https://github.com/l15y/wenda): Large-scale language model call platform, based on ChatGLM-6B to achieve ChatPDF-like functions
* [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): An academic writing and programming toolbox that supports ChatGLM-6B. It has the characteristics of modularization and multi-thread calling LLM, and can call multiple LLMs in parallel.
* [glm-bot](https://github.com/initialencounter/glm-bot): Connect ChatGLM to Koishi to call ChatGLM on major chat platforms

Example projects supporting online training of ChatGLM-6B and related applications:
* [ChatGLM-6B deployment and fine-tuning tutorial](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
* [ChatGLM-6B combined with langchain to implement local knowledge base QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)

Third-party evaluation:
* [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986)

For more open source projects, see [PROJECT.md](PROJECT.md).

## Getting Started

### Hardware Requirements

| **Quantization Level** | **GPU Memory** |
|------------------------|----------------|
| FP16(no quantization)  | 13 GB          |
| INT8                   | 10 GB          |
| INT4                   | 6 GB           |

### Environment Setup

Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable.

In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`.

### Usage

Generate dialogue with the following code

```python
>>> from transformers import AutoTokenizer, AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
>>> model = model.eval()
>>> response, history = model.chat(tokenizer, "你好", history=[])
>>> print(response)
你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
>>> print(response)
晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:

1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。

如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
```
The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can add the `revision="v1.1.0"` parameter in the `from_pretrained` call. `v1.1.0` is the latest version number. For a complete list of versions, see [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log).

### Load the model locally
The above code will automatically download the model implementation and checkpoints by [transformers](https://github.com/huggingface/transformers). The full model implementation can be found at [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If your network environment is poor, downloading model parameters may take a long time or even fail. At this point, you can download the model to the local first, and then load it from the local.

To download models from Hugging Face Hub, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) , then run
```Shell
git clone https://huggingface.co/THUDM/chatglm-6b
```

After downloading the model locally, replace `THUDM/chatglm-6b` in the above code with the path of your local `chatglm-6b` folder to load the model locally.

**Optional**: The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can execute
```Shell
git checkout v1.1.0
```

## Demo & API

We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with:

```shell
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
```

### Web Demo

![web-demo](resources/web-demo.gif)

Install Gradio `pip install gradio`,and run [web_demo.py](web_demo.py):

```shell
python web_demo.py
```

The program runs a web server and outputs the URL. Open the URL in the browser to use the web demo.

Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a web demo based on Streamlit, see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117 ).

#### CLI Demo

![cli-demo](resources/cli-demo.png)

Run [cli_demo.py](cli_demo.py) in the repo:

```shell
python cli_demo.py
```

The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.

## API Deployment
First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](api.py) in the repo.
```shell
python api.py
```
By default the api runs at the`8000`port of the local machine. You can call the API via 
```shell
curl -X POST "http://127.0.0.1:8000" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```
The returned value is
```shell
{
  "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
  "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
  "status":200,
  "time":"2023-03-23 21:38:40"
}
```

## Deployment

### Quantization

By default, the model parameters are loaded with FP16 precision, which require about 13GB of GPU memory. It your GPU memory is limited, you can try to load the model parameters with quantization:

```python
# Change according to your hardware. Only support 4/8 bit quantization now.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
```

After 2 to 3 rounds of dialogue, the GPU memory usage is about 10GB under 8-bit quantization, and only 6GB under 4-bit quantization. As the number of dialogue rounds increases, the corresponding GPU memory consumption also increases. Due to the use of relative position encoding, ChatGLM-6B theoretically supports an infinitely long context-length, but the performance will gradually decline after the total length exceeds 2048 (training length).

Model quantization brings a certain performance decline. After testing, ChatGLM-6B can still perform natural and smooth generation under 4-bit quantization. using [GPT-Q](https://arxiv.org/abs/2210.17323) etc. The quantization scheme can further compress the quantization accuracy/improve the model performance under the same quantization accuracy. You are welcome to submit corresponding Pull Requests.

The quantization costs about 13GB of CPU memory to load the FP16 model. If your CPU memory is limited, you can directly load the quantized model, which costs only 5.2GB CPU memory: 
```python
# For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
```

### CPU Deployment

If your computer is not equipped with GPU, you can also conduct inference on CPU, but the inference speed is slow (and taking about 32GB of memory):

```python
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
```

If your CPU memory is limited, you can directly load the quantized model:
```python
# For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
```

If your encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin`(MacOS), please [load the model locally](README_en.md#load-the-model-locally). 

### Inference on Mac
For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly. (The correct version number should be 2.1.0.dev2023xxxx, not 2.0.0).

Currently you must [load the model locally](README_en.md#load-the-model-locally) on MacOS. Change the code to load the model from your local path, and use the mps backend:
```python
model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
```

Loading a FP16 ChatGLM-6B model requires about 13GB of memory. Machines with less memory (such as a MacBook Pro with 16GB of memory) will use the virtual memory on the hard disk when there is insufficient free memory, resulting in a serious slowdown in inference speed. At this time, a quantized model such as chatglm-6b-int4 can be used. Because the quantized kernel on the GPU is written in CUDA, it cannot be used on MacOS, and can only be inferred using the CPU:

```python
# For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
```

### Multi-GPU Deployment
If you have multiple GPUs, but the memory size of each GPU is not sufficient to accommodate the entire model, you can split the model across multiple GPUs. 

First, install accelerate: `pip install accelerate`, and then load the model using the following method:
```python
from utils import load_model_on_gpus
model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
```

This will deploy the model onto two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. By default, the model is split evenly, but you can also specify the `device_map` parameter to customize the splitting.

## Parameter-efficient Tuning
Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it.

## ChatGLM-6B Examples

The following are some Chinese examples with `web_demo.py`. Welcome to explore more possibility with ChatGLM-6B.

<details><summary><b>Self Cognition</b></summary>

![](examples/self-introduction.png)

</details>

<details><summary><b>Outline</b></summary>

![](examples/blog-outline.png)

</details>

<details><summary><b>Ad</b></summary>

![](examples/ad-writing-2.png)

![](examples/comments-writing.png)

</details>

<details><summary><b>Email</b></summary>

![](examples/email-writing-1.png)

![](examples/email-writing-2.png)

</details>

<details><summary><b>Information Extraction</b></summary>

![](examples/information-extraction.png)

</details>

<details><summary><b>Role Play</b></summary>

![](examples/role-play.png)

</details>

<details><summary><b>Comparison</b></summary>

![](examples/sport.png)

</details>

<details><summary><b>Travel Guide</b></summary>

![](examples/tour-guide.png)

</details>

## License

This repository is licensed under the [Apache-2.0 License](LICENSE). The use of ChatGLM-6B model weights is subject to the [Model License](MODEL_LICENSE)。

## Citation

If you find our work useful, please consider citing the following papers:

```
@misc{glm2024chatglm,
      title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, 
      author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
      year={2024},
      eprint={2406.12793},
      archivePrefix={arXiv},
      primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
}
```

================================================
FILE: UPDATE.md
================================================
## 更新信息
**[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。

![](resources/visualglm.png)

可以通过本仓库中的 [cli_demo_vision.py](cli_demo_vision.py) 和 [web_demo_vision.py](web_demo_vision.py) 来运行命令行和网页 Demo。注意 VisualGLM-6B 需要额外安装 [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) 和 torchvision。更多信息参见 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B)。

**[2023/05/15]** 更新 v1.1 版本 checkpoint,训练数据增加英文数据以平衡中英文数据比例,解决英文回答中夹杂中文词语的现象。

<details><summary><b>以下是更新前后的英文问题对比:</b></summary>

* 问题:Describe a time when you had to make a difficult decision.
  - v1.0:
  ![](resources/english-q1-old.png)
  - v1.1:
  ![](resources/english-q1-new.png)
* 问题:Describe the function of a computer motherboard
  - v1.0:
  ![](resources/english-q2-old.png)
  - v1.1: 
  ![](resources/english-q2-new.png)
* 问题:Develop a plan to reduce electricity usage in a home.
  - v1.0:
  ![](resources/english-q3-old.png)
  - v1.1:
  ![](resources/english-q3-new.png)
* 问题:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
  - v1.0:
  ![](resources/english-q4-old.png)
  - v1.1:
  ![](resources/english-q4-new.png)
</details>

**[2023/04/16]** 增加 INT8 量化后的模型 [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8)。增加多卡部署(感谢 [@Cherrysaber](https://github.com/Cherrysaber))。

**[2023/04/06]** 优化web demo的界面(感谢 [@tuteng0915](https://github.com/tuteng0915))。移除embedding中的image token以减小显存占用(需要更新模型文件`pytorch_model-00001-of-00008.bin`和`pytorch_model-00008-of-00008.bin`,感谢 [@silverriver](https://github.com/silverriver) 提出的想法)。去掉了对 `icetk` 的依赖(需要更新模型文件`ice_text.model`)。

**[2023/03/31]** 增加基于 [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2) 的高效参数微调实现,INT4 量化级别下最低只需 7GB 显存即可进行模型微调。详见[高效参数微调方法](ptuning/README.md)。

**[2023/03/23]** 增加 API 部署(感谢 [@LemonQu-GIT](https://github.com/LemonQu-GIT))。~~增加 Embedding 量化模型 [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe)~~ (已停止维护)。增加配备 Apple Silicon 芯片的 Mac 上 GPU 加速的支持。

**[2023/03/19]** 增加流式输出接口 `stream_chat`,已更新到网页版和命令行 Demo。修复输出中的中文标点。增加 INT4 量化后的模型 [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4)


## Update
**[2023/05/17]** Release [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B), a multimodal conversational language model supporting image understanding.

![](resources/visualglm.png)

You can run the command line and web demo through [cli_demo_vision.py](cli_demo_vision.py) and [web_demo_vision.py](web_demo_vision.py) in the repository. Note that VisualGLM-6B requires additional installation of [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer/) and torchvision. For more information, please refer to  [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B).

**[2023/05/15]** Update the checkpoint of v1.1 version, add English instruction data for training to balance the proportion of Chinese and English data, which solves the phenomenon of Chinese words mixed in English answers .

<details><summary><b>The following is a comparison of English questions before and after the update</b></summary>

* Question: Describe a time when you had to make a difficult decision.
  - v1.0:
  ![](resources/english-q1-old.png)
  - v1.1:
  ![](resources/english-q1-new.png)
* Question: Describe the function of a computer motherboard
  - v1.0:
  ![](resources/english-q2-old.png)
  - v1.1:
  ![](resources/english-q2-new.png)
* Question: Develop a plan to reduce electricity usage in a home.
  - v1.0:
  ![](resources/english-q3-old.png)
  - v1.1:
  ![](resources/english-q3-new.png)
* Question:未来的NFT,可能真实定义一种现实的资产,它会是一处房产,一辆汽车,一片土地等等,这样的数字凭证可能比真实的东西更有价值,你可以随时交易和使用,在虚拟和现实中无缝的让拥有的资产继续创造价值,未来会是万物归我所用,但不归我所有的时代。翻译成专业的英语
  - v1.0:
  ![](resources/english-q4-old.png)
  - v1.1:
  ![](resources/english-q4-new.png)
</details>

**[2023/04/16]** Added INT8 quantized model [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8). Added multi-GPU deployment (thanks to [@Cherrysaber](https://github.com/Cherrysaber)).

**[2023/04/06]** Improve the web demo interface (thanks to [@tuteng0915](https://github.com/tuteng0915)). Remove the image tokens in the embedding layer to reduce the memory usage (need to update the model files `pytorch_model-00001-of-00008.bin` and `pytorch_model-00008-of-00008.bin`, thanks to [@silverriver](https:/ /github.com/silverriver) for proposing the idea). Removed dependency on `icetk` (need to update model file `ice_text.model`).

**[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details.

**[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon.

**[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4). 

================================================
FILE: api.py
================================================
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)


================================================
FILE: cli_demo.py
================================================
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import readline

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()


================================================
FILE: cli_demo_vision.py
================================================
import os
import platform
import signal
import sys

from transformers import AutoTokenizer, AutoModel
import readline

tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history, prefix):
    prompt = prefix
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    global stop_stream
    while True:
        history = []
        prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序"
        print(prefix)
        image_path = input("\n请输入图片路径:")
        if image_path == "stop":
            break
        prefix = prefix + "\n" + image_path
        query = "描述这张图片。"
        while True:
            count = 0
            for response, history in model.stream_chat(tokenizer, image_path, query, history=history):
                if stop_stream:
                    stop_stream = False
                    break
                else:
                    count += 1
                    if count % 8 == 0:
                        os.system(clear_command)
                        print(build_prompt(history, prefix), flush=True)
                        signal.signal(signal.SIGINT, signal_handler)
            os.system(clear_command)
            print(build_prompt(history, prefix), flush=True)
            query = input("\n用户:")
            if query.strip() == "clear":
                break
            if query.strip() == "stop":
                sys.exit(0)


if __name__ == "__main__":
    main()


================================================
FILE: improve/README.md
================================================
# ChatGLM-6B Badcase 反馈计划
ChatGLM-6B 自3月14号发布以来受到了广大开发者和用户的喜爱,截至4月22号 GitHub 的 star 数达到 2 万,各个渠道模型的累计下载量过 100 万,并连续 12 天居 Hugging Face 全球大模型下载榜第一名。 与此同时,有一批基于 ChatGLM-6B 的[优秀开源项目](https://github.com/THUDM/ChatGLM-6B)出现,在各个平台也引起了广泛好评和关注。此外,基于 GLM-130B 的千亿对话模型 ChatGLM 也自3月14号开始了第一阶段的邀请制内测,得到了内测用户的好评和支持。谢谢大家对 ChatGLM 及其 6B 开源版本的大力支持!

接下来,我们想邀请大家一起推动 ChatGLM-6B 的进一步提升,一起推动模型的发展。尽管ChatGLM-6B已初具符合人类偏好的问答对话能力,在相当多的指令和问题上,其回答仍存在不理解复杂指令和任务含义,缺乏领域概念理解,事实性错误,生成有害内容,对话上下文不一致等诸多问题。尽管我们提供的[微调代码](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)能够让用户通过自主训练修复部分问题,但因为神经网络的[灾难性遗忘](https://picture.iczhiku.com/weixin/message1587593113355.html)问题,微调后的模型往往会失去在通用领域的对话能力或者因数据较少而缺乏泛化能力。为了解决这些问题,进一步提升 ChatGLM-6B 的能力,我们启动了 ChatGLM-6B Badcase 反馈计划。

具体来说,对于在使用 ChatGLM-6B 过程中遇到的表现不佳的Badcase对应的具体指令和提问,您可以修改或从头撰写您认为合适的正确答案,并反馈给我们改进 ChatGLM-6B。**请您确保提交的数据不包含任何个人信息、商业秘密或可能危害国家安全、侵害第三方知识产权的内容。** 我们会定期(每2-4周)对数据的有用性与正确性进行筛选,将筛选通过的数据,与通用域的对话数据一起加入到模型训练中,并**更新发布开源的模型参数**。**您提供的数据无论是否筛选通过,除非获得您的许可或根据国家法律规定和监管要求外,我们不会将您提供的数据对外公开。**

您提供的数据如被筛选通过,您将同时优先获得最新版本的 ChatGLM-6B 模型的体验资格。此外,如果您愿意,您的用户名还将出现在 ChatGLM-6B Github页面的数据贡献者名单中,以此感谢您对推进大模型开源事业发展的帮助。您的支持和建议将为我们优化大模型提供源源不断的动力,在此向您表达我们最深的敬意与感谢!

如果您希望参与反馈,请填写[问卷](https://www.wjx.cn/vm/rAoGx9X.aspx#)并按照具体要求上传。提交的数据为 jsonline 格式,每行的内容为
```json lines
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿x s裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
```
其中,`prompt` 部分为模型的输入,`response` 部分为期望的模型输出。为了保证模型的性能,请在输入中尽可能详细地表述任务的类型和期望的输出格式。针对某项具体的任务,为了使模型得到充分的训练,一般需要100条左右的训练数据。


================================================
FILE: improve/data_sample.jsonl
================================================
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖", "response": "圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则", "response": "宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙", "response": "踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接", "response": "想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#h*材质#蚕丝*风格#复古*图案#条纹*图案#复古*图案#撞色*衣样式#衬衫*衣领型#小立领", "response": "小女人十足的条纹衬衣,缎面一点点的复古,还有蓝绿色这种高级气质复古色,真丝材质,撞色竖条纹特别的现代感味道,直h型的裁剪和特别的衣长款式,更加独立性格。双层小立领,更显脸型。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型#高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领", "response": "这款连衣裙,由上到下都透出一丝迷人诱惑的女性魅力,经典圆领型,开口度恰好,露出你的迷人修长的脖颈线条,很是优雅气质,短袖设计,在这款上竟是撩人美貌,高腰线,散开的裙摆,到小腿的长度,遮住了腿部粗的部分,对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#文字*图案#印花*衣样式#卫衣", "response": "一款非常简洁大方的纯色卫衣,设计点在于胸前的“<UNK><UNK>”的中文字印花,新颖特别,让人眼前一亮。简单又吸睛的款式,而且不失时髦感,很适合个性年轻人。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#黑色*颜色#灰色*颜色#姜黄色*风格#休闲*图案#线条*图案#撞色*衣样式#毛衣*衣袖型#落肩袖", "response": "看惯了灰色的冷淡和黑色的沉闷感,来一点醒目的彩色增添点活力吧。亮眼又吸睛的姜黄色色调,嫩肤显白非常的有设计感。趣味的撞色和宽松的版型相交辉映,修饰身形小缺点的同时,时尚又百搭。优雅的落肩袖,轻松修饰肩部线条,让毛衣上身凸显出一丝慵懒随性的休闲感,时尚魅力尽显。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*风格#潮*图案#印花*图案#撞色*衣样式#衬衫*衣领型#圆领*衣长#中长款*衣长#常规*衣袖长#无袖", "response": "黑与白,两种最极端的颜色却轻松搭配成了经典,就像此款衬衣,无需过多装饰,仅色调就足够醒目个性,受潮<UNK>所喜欢。做了无袖中长款的样式,走路带风的感觉着实不错,圆领的设计,不是常规的衬衫领,少了点正式反而有种休闲感觉,适合孩子们穿着。后背大面积撞色印花装点,是时尚潮流的象征,也让衣衣不至于单调,轻松就能穿出彩。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#街头*风格#休闲*风格#朋克*图案#字母*图案#文字*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#对称", "response": "个性休闲风的连帽卫衣造型时髦大方,宽松的版型剪裁让肉肉的小宝贝也可以穿着,保暖的连帽设计时刻给予宝贝温柔的呵护,袖子和后背别致时髦的字母印花点缀,满满的街头元素融入,演绎休闲朋克风,对称的小口袋美观大方,方便放置更多的随身物品。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#链条", "response": "简单大气的设计,不费吹灰之力就能搭配的时髦范儿。时尚的配色一点都不觉得平淡了,有种浑然天成的大气感。强调了整体的装饰,和谐又不失个性,搭配裤装帅气十足,搭配裙子精致优雅。链条和肩带的搭配让使用感更加舒服,单肩手提都好看。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#显瘦*材质#牛仔布*颜色#深蓝色*风格#复古*图案#复古*图案#线条*裤腰型#高腰*裤口#微喇裤", "response": "深蓝色的高腰牛仔裤,修身的款式勾勒出纤细的美腿。牛仔裤的裤脚设计<UNK>张开的喇叭型,巧妙地修饰了小腿的线条,洋溢着复古的年代感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#潮*风格#性感*图案#条纹*图案#蝴蝶结*衣样式#衬衫*衣领型#一字领*衣门襟#系带*衣款式#不对称", "response": "这是一件显得特别清新的衬衣,采用了条纹的设计,给予人一种甜美可人的气质。并且融合了别致的斜肩一字领设计,高调的展示出性感的锁骨,将迷人的香肩展现在外,性感中不失去清纯的气息。袖口处的蝴蝶结系带装饰,增添了俏皮的韵味,简洁大方。且在下摆处采用了不对称的设计,增强了视觉效果,更显潮流。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#复古*图案#复古*裤型#直筒裤*裤款式#纽扣*裤腰型#高腰", "response": "作为基础款单品,牛仔裤也<UNK><UNK>,想要呈现给大家的是——每次搭配都有新感觉。裤子经过复古做旧处理,风格鲜明,也很注重细节,连纽扣也做了统一的做旧处理,融入个性十足的磨破设计,高腰直筒basic裤型,修饰身材,穿出高挑长腿。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*图案#线条*图案#刺绣*衣样式#针织衫*衣领型#v领", "response": "一款温暖柔软又富有弹性的针织衫,不仅可以抵御严寒侵袭,还能更好地进行搭配。v领的设计,能勾勒出迷人的天鹅颈以及衬托出娇小的脸型。宽松又别致的剪裁,能从视觉上显露纤长的下半身,起到显瘦的效果。直筒造型的袖子,修饰出优美的手臂线条,衣身上的方格刺绣,时尚又吸睛。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#绿色*风格#清新*图案#线条*衣样式#衬衫*衣领型#翻领", "response": "绿色的衣身上镶嵌着<UNK>,就是这款衬衫最大的迷人之处,“红花配绿叶”般的色调,将清新气息阐述的淋漓尽致。经典的翻领更是贴心,修饰颈部线条的同时,尽显精致干练的气质,出街轻松凹造型。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*图案#字母*图案#文字*图案#印花*图案#撞色*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "这款外套采用了撞色拉链织带以及字母印花设计。这两种元素的融入使外套不会显得过于单调沉闷,吸睛而亮眼,充满年轻与朝气感,非常减龄。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*版型#h*风格#复古*图案#复古*图案#刺绣*裙长#连衣裙*裙袖长#长袖*裙领型#翻领*裙衣门襟#单排扣", "response": "本款连衣裙整体采用h型的轮廓设计,藏肉显瘦,不挑身材,适合各种身形的人穿着。小翻领的领口设计,使得本款连衣裙穿在身上看起来十分的精神帅气,具有青春活力。单排扣的衣门襟设计,又给本款连衣裙带来了一丝的复古味道。裙身上的刺绣花朵装饰,使得本款连衣裙不显得单调,富有层次感,上身给人一种独特的时尚魅力。长袖的设计,更加的贴合手臂曲线,上身更加的舒适贴身。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#粉色*风格#清新*衣样式#外套*衣样式#西装*衣门襟#双排扣", "response": "这款外套设计成西装的版型,彰显经典优雅的气质,结合了粉色又添清新气息,甜美百搭时尚感满满。利落的版型简洁流畅,亮色双排扣更添精致感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#线条*衣样式#风衣*衣样式#外套*衣门襟#拉链*衣款式#拉链*衣款式#松紧带*衣款式#连帽*衣款式#收腰", "response": "选自品牌江南布衣的一款女士长风衣外套,选用轻薄的<UNK><UNK>,穿着灵活毫无压力。直筒版型简洁利落,长过膝盖的长度穿着个性十足,连帽宽大有型,富有活力,<UNK>拉链开合,拉上拉链有一丝酷劲,敞开穿则更休闲,连帽领翻开修饰颈部线条。松紧带收腰设计,低调的分割上下比例,打造显高小心机。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#街头*风格#简约*图案#刺绣*裤长#短裤*裤款式#钉珠*裤口#毛边", "response": "又到了光腿穿裙子和短裤的时候了,BRAND的这款短裤,采用柔软透气的纯棉牛仔面料,穿着舒适无束缚感。而简约的版型加入了精美的刺绣和钉珠装饰,提升了整体的品质感,显得精美而又立体饱满。搭配下摆的毛边装饰,散发出不羁的街头感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*颜色#黑色*图案#条纹*衣样式#衬衫*衣领型#翻领*衣门襟#系带*衣款式#拼接*衣款式#露肩", "response": "一款老鹰图案露肩衬衫,露肩系带的设计,少女感十足。老鹰图案的设计,更添几分趣味感。条纹面料和牛仔面料的拼接设计,给人一种风度的层次感。小翻领的设计十分的精致,搭配一件黑色打底裤也吸晴万分。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*裙型#百褶*裙长#半身裙*裙款式#拼接*裙款式#腰带", "response": "一款颇有设计感的半身裙,单侧雪纺百褶的拼接设计,规整排列的层次感带来立体效果,增加了裙身的廓形,行走间更是带来柔美的灵动气息,轻而易举穿出优雅的轻熟风,呈现十足的女人味来。同面料延伸处理的半固定腰带,可以自然的垂落下来,也算是为整体打造造型亮点,彰显你独特的时尚品味,迎合早春对轻盈的追求。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#性感*裙型#包臀裙*裙型#鱼尾裙", "response": "修身包臀版型结合性感鱼尾裙摆设计,彰显婉约优雅风情之余,为整体注入几分俏皮灵动气息。且下摆辅以律动感摺裥元素,更烘托出女性浪漫精致的一面。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*颜色#绿色*图案#线条*裙长#连衣裙*裙领型#v领*裙款式#勾花镂空", "response": "连衣裙可以让你在旋转与跳跃间,都散发出万种风情,受到了万千女性的喜爱。这款连衣裙选用绿色调,既散发出活力气息,又增添了高雅的气质。而镂空的钩花设计,则为其增添了浪漫的风情,同时更显美观与时尚。再加上v领的设计,不仅映衬出精致的脸颊,还打造出优美的颈部线条。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#淑女*衣样式#毛衣*衣领型#高领", "response": "高领毛衣一直是网红妹子,因为穿着它有一种淑女甜美气质。它最大的亮点在于它的高领设计和花边装饰。在淑女干练的气质基础上又增加了一些少女的甜美气息,穿着非常有型,最佳搭配小白鞋。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*裤长#九分裤*裤型#阔腿裤*裤款式#拉链*裤腰型#高腰*裤口#开叉", "response": "九分裤长,把妹子的拉长了腿的比例,配合高腰设计,瞬间显得妹子的腿长了很多,一下子自信满满啦。采用侧面隐藏拉链设计,穿脱方便又舒适。设计感十足的开叉裤脚,身上的摩登<UNK>浓了。这个春天妹子的腿型,就交给阔腿裤啦。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "经典的套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#卡通*裙长#连衣裙", "response": "传奇而又经典的卡通形象,米老鼠似乎已经成为了孩童风格的一种标志,大小不一的头像以及奇趣的表情设计。满版的点缀风格让整个连衣裙洋溢着独特的天真气质,加之面料小口袋的点缀,小小的造型呈现出灵巧而又可爱的格调,让宝贝俏皮萌动。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#条纹*图案#刺绣*裙型#背带裙*裙下摆#毛边", "response": "假两件版型的设计,给人一种错觉,大大增添自身时髦感。毛边裙摆的采用,看起来活力十足。设计师解决了以往穿脱不方便的问题,应用的可调节背带设计,非常的人性化。裙子上的花朵刺绣图案,看起来也栩栩如生,同时也展示出了精湛的做工手艺。为了与女人自身清纯的一面形成呼应,应用的条纹图案非常完美。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#街头*风格#休闲*裤长#五分裤*裤腰型#松紧腰", "response": "这款休闲五分裤,采用亲肤软牛仔,洗水磨白形成深浅对比,更加个性。大弹力松紧腰,舒适贴合,一点都不紧勒。裤子门襟时尚的设计,为细节加分。立体双贴袋,腰间系带的点缀更吸睛。精致的裁剪,或是干练整洁的走线和宽松版型,是对街头的描写。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*图案#条纹*图案#蕾丝*衣样式#衬衫*衣领型#立领*衣款式#荷叶边", "response": "条纹衬衫是引领时尚圈的常青树,尤其给人舒适感官享受的蓝白条纹,更是深得时尚icon的喜爱。加之搭载经典的立领秀出纤长的玉颈,更显气质优雅。肩膀上饰有薄薄的蕾丝,打破了条纹衬衫的干练,更添别样风情。荷叶边的蔓延更显气质甜美,自然吸睛无数。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*风格#休闲*图案#线条*裙型#牛仔裙*裙长#半身裙", "response": "牛仔半身裙作为时尚宠儿,一直被很多潮人捧在手心,zui妙的莫过于它<UNK>的时髦感以及百搭易驾驭的属性。裙身设计了自然的腰型,拉长腿部线条,让小仙女们感受到大长腿。以及两侧插袋的造型,显得比较随意,休闲的感觉。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#刺绣*衣样式#卫衣*衣袖型#罗纹袖口", "response": "乍一看很平凡的纯色系卫衣,<UNK>暗藏的刺绣玄机,就足够把时髦的张力表现得不凡。很有包容性的廓形,舒适的罗纹收口,宽大的样子却依旧<UNK>既定的风格,让你的潇洒随性表现得收放自如。呆萌查理的袖间刺绣,极简的漫画笔触巧塑生动有趣的风格。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*衣样式#衬衫*衣样式#风衣*衣款式#抽绳", "response": "风衣带有一种随性大气的感觉,在春风拂面的日子里能衬托出你的气质。草绿色的衣身配色,与与自身清纯干净的性格形成了呼应。具有一定实用性的下摆抽绳,可以让你任意的变换风格。抛弃了衬衫领的设计应用的设计,更能将你帅气的一面展示出来。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*风格#复古*风格#清新*图案#条纹*图案#复古*衣样式#针织衫*衣样式#开衫*衣长#常规*衣款式#拼接*衣款式#纽扣*衣款式#罗纹", "response": "慵懒气质的针织开衫,充满了复古的情调,奶奶级的麻花编织手法,充满立体感的同时保暖效果也是满分。下摆的罗纹拼接,让针织衫回暖性更棒。活泼的<UNK>条纹拼接,跳脱出常规配色,清新色调的选用,更加衬托出肌肤的雪白。精致的纽扣点缀,反光的质感让针织衫充满现代感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#复古*图案#蝴蝶结*图案#复古*图案#波点*衣样式#衬衫*衣领型#立领*衣门襟#系带*衣款式#木耳", "response": "【<UNK>说】<UNK>衬衫,大波点气质复古从立领上延伸的长系带,可轻松绑成蝴蝶结,甜美感加分采用打缆工艺的松紧袖口边边处的木耳<UNK>很可爱"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*风格#青春*图案#字母*图案#文字*裙型#网纱裙*裙袖长#无袖*裙领型#圆领", "response": "大气的圆领舒适贴合,彰显出女孩儿精神的气质。无袖的款式与圆领相迎合,简约的同时又不失时尚风采。前身由可爱蝴蝶图案点缀,亮丽的字母映衬其上,诉说着一丝精美感。橙色网纱裙摆造型优雅唯美,与上身的图案相呼应,十分富有青春的气息,伴随着步伐的行走间,带出一丝别致浪漫的风情。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#丝绒*风格#复古*图案#复古*衣样式#雪纺衫*衣袖型#喇叭袖*衣款式#木耳边*衣款式#飘带*衣款式#荷叶边", "response": "这款雪纺衫,采用具有复古韵味的荷叶边元素,加上丝绒质感的加长飘带,洋溢着浪漫古典的韵味。<UNK>两侧镶有包扣,和立体木耳边装饰,大大提升时髦指数。而流线型喇叭袖设计,充满灵动质感,为造型平添活力。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*版型#显瘦*裙下摆#荷叶边*裙腰型#高腰*裙长#半身裙", "response": "很简洁百搭的一款半裙,裙身荷叶边设计,飘逸灵动,上身更显层次感丰富。高腰造型,版型优良,衬显修长双腿。裙子做的比较宽松,包容性敲好,遮肉效果棒棒的。非常的显瘦哦,选用精品梭织面料,垂感好,肌理细致,上身敲舒服哟。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#青春*风格#性感*图案#线条*裙下摆#开叉*裙长#连衣裙*裙领型#翻领*裙款式#腰带*裙款式#衬衫式", "response": "设计师以衬衫式的创作灵感,巧妙地搬运到连衣裙身上,中性又不失性感;时尚小翻领设计,巧妙衬托颈部线条,彰显青春派的艺术时尚,小资派的精彩演绎。耳目一新的双腰带设计,既突出了腰线又感觉很前卫;下面走心的大开叉设计,更能激发人的好奇心,营造出无人超越的高级性感,只需一眼就令人<UNK>。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*风格#性感*图案#印花*裙下摆#荷叶边*裙长#连衣裙*裙袖型#灯笼袖", "response": "这款连衣裙走的是性感大方的风格路线,展现出你的大大咧咧的性情,非常的有趣。选用了宽松的版型,配合星空印花的图案,塑造出新颖有趣,不失活力四射的印象感。荷叶边的裙摆设计,突显出飘逸性感的一面。配合灯笼袖的袖型细节,体现出<UNK>的一面。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*材质#水洗*颜色#浅色*风格#休闲*风格#性感*图案#线条*裙型#牛仔裙*裙型#直筒裙*裙下摆#开叉*裙下摆#毛边*裙腰型#高腰", "response": "浅色水洗效果牛仔裙,高腰设计融合修身直筒廓形,凸显纤细腰部和迷人翘臀,美化勾勒性感身材曲线。正面开叉细节有效拉长腿部线条,灵动性感。磨毛边下摆设计,带来休闲随性气息。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#条纹*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#罗纹", "response": "这款连帽卫衣自带休闲魅力,将杜嘉班纳的品牌标志以印花的形式装饰在衣身前幅,展现出华丽不失看点的视觉效果,每时每刻都在彰显不凡品味。罗纹条纹袖口和下摆,不仅能使卫衣更帅气惹眼,还能为整体增加一股前卫之风。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#简约*图案#卡通*图案#蝴蝶结*图案#印花*衣样式#衬衫", "response": "大面积的卡通兔子印花,童趣满满,再加上领口的蝴蝶结装饰织带。充满童趣的同时又不失小女生的甜美气息,相当减龄。这款衬衫选用真丝面料,真丝面料不仅轻薄,而且柔滑、亲肤,就好像人的第二层肌肤般带给你清凉舒适的穿着感觉。合身的版型,裁剪得干净利落,简约又不失时尚气息,打造干练的气场。这款衬衫日常十分百搭,不仅可以与其他服饰搭配,作为一件单品也十分出彩。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*材质#水洗*风格#复古*风格#简约*图案#复古*图案#线条*裤长#九分裤*裤款式#不规则*裤口#毛边", "response": "misssixty的这款单品延续经典的九分牛仔裤版型,结合贴合身形的剪裁,展现出柔美修长的腿部线条;不同的位置做了不同程度的水洗复古工艺,使得裤身更加立体厚重;此外,裤脚处采用了微微不规则的毛边剪裁,为简约的整体注入一丝随性之感;再加上<UNK>相互呼应的翅膀状图案点缀,瞬间带来一丝浪漫唯美的味道。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#线条*衣样式#卫衣*衣领型#圆领*衣袖型#收口*衣门襟#套头*衣款式#螺纹", "response": "使用经典的螺纹圆领来展开设计,将衣型打造成套头卫衣的款式,穿着时轻松收口,将颈部线条修饰出挺拔优美的的效果,让穿着更加具有精气神。衣身以纯色作为主色调,配上经典的小企鹅logo,将正面点缀,它拥有一个俏皮的小蝴蝶领结,充满细节感,使得衣身吸睛耀眼。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*裤长#九分裤*裤型#直筒裤", "response": "c小小的这样一条迷人的牛仔裤彰显出你的大气个性,它的别致直筒版型十分的高端迷人,让你吸睛十足。个性九分的版型展示出你的迷人小脚踝。它的大气牛仔材质,十分的舒适洒脱,迷人更有型。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*颜色#纯色*风格#简约*图案#纯色*图案#线条*图案#蕾丝*衣样式#衬衫*衣领型#v领", "response": "一款简约的纯色衬衫,采用了个性的大v领,露出柔美的锁骨和颈部线条,散发出清爽迷人的气质;点缀精美的蕾丝花边装饰,波浪形的花边很有美感,增加了视觉亮点。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#撞色*裙下摆#垂坠*裙长#连衣裙*裙袖长#无袖*裙袖型#收口*裙款式#拼接*裙款式#绑带*裙款式#波浪", "response": "来自奥芝国的推出的无袖连衣裙,精选弹力冰丝材质穿柔软垂坠性很好,适合春夏秋三季穿搭。腰部的撞色波浪纹弹力腰封拼接,并以交叉绑带式收口,修饰腰身轻松穿<UNK>人大长腿。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#针织*颜色#灰色*颜色#深蓝色*图案#线条*衣样式#毛衣", "response": "这是一款专为胖孩子设计的针织毛衣,加肥加大的立体版型,利落有型穿着合体不臃肿,穿着更加帅气显瘦;领口、袖口和下摆收紧处理使衣衣更加利落有型,久穿久洗也不易磨损和变形,颇具品质感;深蓝色的大身巧妙地加入一些灰色线条修饰活泼大方,孩子穿上它,洋溢着青春活力。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*材质#网纱*风格#街头*衣样式#衬衫*衣款式#拼接*衣款式#勾花镂空*衣款式#钉珠", "response": "时髦又帅气的牛仔拼接裙,利用多材质拼接演绎刚柔并济的设计。硬朗的牛仔衬衫以镂空拼接,构造出深浅的色系变化,加上钉珠铆钉的装饰,更是玩味出十足的街头帅气。下身拼接的网纱半裙,层次细腻又丰富,两侧加入牛仔插袋呼应上身面料,带来一体感设计。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*颜色#蓝色*颜色#浅蓝色*风格#性感*裙型#牛仔裙*裙型#包臀裙*裙下摆#开叉*裙款式#拼接*裙款式#纽扣", "response": "mm们<UNK>拼接风呢?这款牛仔裤是非常有趣的拼接风,浅蓝色和原蓝色的牛仔拼接在一起,非常吸引眼球。在左侧的裙摆处还做了开叉设计,微微露出腿部皮肤,展现性感姿态。包臀的设计,凸显圆润的臀部。前幅一排金属纽扣,增添细节感和精致度。喜欢的mm千万不要错过~"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#蕾丝*颜色#粉色*风格#清新*图案#碎花*图案#线条*图案#蕾丝*裙型#a字*裙下摆#花边*裙领型#圆领*裙款式#飘带", "response": "清新的小碎花缀满衣身,以淡雅的粉色调为底色,焕发出甜美温婉的少女气息。简洁的圆领设计,柔化脸部线条,加上蕾丝飘带点缀,更显娇俏减龄。下摆蕾丝花边分割裙裾,转身间将浪漫挥洒。散开的a字裙摆,恰到好处遮住了臀部和腿部粗的部分,有很好的修饰作用。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#淑女*图案#植物*图案#印花*裙型#百褶*裙长#连衣裙*裙领型#娃娃领*裙款式#拼接*裙款式#腰带", "response": "法式浪漫情怀,由这款印花连衣裙为你抒写。蝴蝶花卉印花铺陈裙身,蝴蝶翩跹BRAND花丛,浪漫迷人美如画,法式风情呼之欲出。娃娃领的设计,凸显一身柔美的淑女气质。裙摆百褶的设计,<UNK>飞舞更添灵动飘逸的美。腰带拼接的设计,完美打造显瘦显高的身材比例。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*风格#复古*风格#简约*风格#休闲*图案#复古*图案#线条*图案#印花*裙长#连衣裙", "response": "这一款雪纺连衣裙复古的小立领带来不一样的惊喜,不仅拉伸了脖颈的线条,同时衬托出娇小的脸型。衣身大大的印花很有质感,简约休闲中透露着复古精致的美丽。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#文艺*风格#简约*风格#清新*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "飘飘落落,暖色的布料上纷纷落落的铺着羽毛,灰常有意境的一款连衣裙。羽毛是这款连衣裙最大的亮点,色彩也丰富饱满,凸显的文艺感也灰常强烈,满满的文艺清新气息;简约大方的设计,有种不喧嚣的热烈感;凸显内敛的气质。搭大衣、棉服外套不仅保暖又灰常的有韵味,而且这款不仅做了开扣的设计,还做了隐形的小拉链!是可哺乳的款式,方便孕后哺乳穿,墙裂推荐!"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#简约*风格#青春*风格#清新*风格#性感*图案#条纹*图案#撞色*裙下摆#开叉*裙长#连衣裙*裙款式#拼接*裙款式#吊带", "response": "这款针织吊带连衣裙展现青春时尚的格调,双侧撞色条纹的拼接简约经典,散发出清新爽朗的气息,显得格外惹眼,营造出明媚动人的视觉吸引力。赋予简约的吊带裙满满的活力,开叉的剪裁性感别致,充满小女人的韵味。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#街头*风格#青春*衣样式#t恤*衣领型#圆领", "response": "三叶草的这款体恤面料比较舒适,穿起来也能很好的透气排汗。整体的设计风格就是经典的款式,所以说是街头常年流行的必备。圆领的领口设计在穿脱时起到了方便。同时修饰脸部轮廓,更显小脸。三叶草的标志也是最为独特的品牌标识,穿出了个人的品味。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*风格#休闲*风格#清新*衣样式#外套*衣款式#连帽", "response": "春天家以清新白色为主基调打造的这款外套,整体采用了直筒的极简剪裁配合休闲感的连帽设计,穿着在身上的舒适度较高。设计师为这款上衣做了<UNK>口的袖子和下摆的处理,穿着后对于身形的修饰效果会更为出众,显得较为得体、大方。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#军绿色*风格#复古*风格#文艺*风格#知性*风格#休闲*风格#潮*图案#复古*图案#撞色*衣样式#外套*衣样式#西装*衣领型#西装领*衣长#短款*衣袖型#插肩袖", "response": "短款西装小外套,结合了知性和休闲两种风格,在现代的潮流款式中又融入了淡淡的复古韵味。端庄典雅的军绿色衣身,带着自由舒畅的旅行感,款式上选用利落率性色西装领,宽松闲适的插肩袖,门襟选用撞色的两粒扣设计,复古文艺又简洁随性。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "亮眼的橙红色展示出迎面而来的热情感,衬托肤色白皙红润,在宴会上气场十足。方形的镂空点缀着衣领下方,增加看点散发出小女人的妩媚感。独特的衣袖造型倾斜而下,修饰手臂线条非常修长,在举手投足间优雅又大气。收腰的版型设计修饰腰部线条更纤细,打褶的裙摆在行走时灵动十足,仿佛<UNK>的精灵一般。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "裙子表面的镂空花网就使其充满了很强的设计美感,首先是肩部将落肩袖和镂空图案相结合,白皙的肌肤隐隐约约,而且能够很好的缩小肩宽比例。v型领口修饰拉长颈部线条和显得脸小。裙子做了收腰裁剪,并将腰线提高,轻松拉长下半身身材比例,裙摆也更加挺括,从而能够解决胯宽等身材烦恼。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#涤纶*衣样式#风衣*衣袖型#灯笼袖*衣款式#纽扣*衣款式#飘带", "response": "风衣在摒弃了传统的版型样式,将袖子设计成花苞型的灯笼袖,与春天搭配得恰到好处。并在袖子处装饰了四颗纽扣,采用飘带作为松紧调节,增添层次感更显个性别致。除此之外,风衣采用涤纶材质制成,垂顺感好挺括修身,结合小a字形轮廓,更显身形高挑秀美,并且让矮小个的女性也能撑起风衣的气场。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#英伦*风格#简约*风格#休闲*图案#格子*图案#线条*衣样式#西装*衣领型#翻领*衣门襟#双排扣", "response": "这一款休闲西装简约利落的翻领,可以很好地修饰脸型和颈部线条,显脸小的同时又让脖子看上去更纤细。加上精致的格纹装饰,视觉美丽凸显英伦风。而且双排扣设计,时尚大气美观实用。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#文艺*风格#休闲*风格#性感*裙长#半身裙*裙长#连衣裙*裙款式#拼接", "response": "连衣裙的灵感来自于<UNK>匠人穿着围裙的状态,设计师将针织上衣与半裙结合,整体松软舒适,且不失休闲随性感。裙摆不同材质的拼接,带来丰富的层次细节,让时髦度倍增。偏暗调的配色融入文艺田园气息,显随性姿态。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#复古*图案#复古*裙下摆#荷叶边*裙长#连衣裙", "response": "对于女孩子来说,喜爱连衣裙是与生俱来的!几乎没有问题是一条裙纸<UNK>的~BRAND这款裙子整体的设计有点小复古的感觉,而且艳丽的枣红色也是复古色的代表,上身穿着十分衬肤显白哦。个性而时髦的挂脖式领口露出锁骨很是撩人,另外领口至腰间的衣身前片还加入了很有灵动感的荷叶边作为点缀,瞬间点亮了整体的造型感,由内而外散发的优雅而温柔的气质无人能挡。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#日系*风格#简约*图案#线条", "response": "很喜欢这款简单却不简约的时尚牛仔裤,在夏天可以穿出个性与时尚。整个风格比较偏向于日系的身体,任何妹子都能够轻松驾驭,最重要的是版型。穿上特别修饰腿部的线条,打造出了高挑的身材,让你看起来非常有自信的呢,这手工的工艺凸显出了无限的高级质感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#街头*风格#潮*裙型#a字", "response": "孕期就一定要穿的沉闷单调吗?热爱潮流的怎能束缚自己个性的心呢,这款裙子采用a字型设计,让你搭配更为轻松随意,飘逸的撞色织带设计,即刻将原本沉闷的空气也带动的活跃起来。从街头到<UNK>,尽显潮流个性时尚。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*颜色#浅蓝色*风格#街头*风格#休闲*裤型#直筒裤*裤款式#破洞", "response": "破洞元素已变成彰显个性的元素,这款浅蓝色牛仔裤糅合磨白磨破设计,弥漫摩登个性格调,而且破洞设计,打破裤装闷热形象,休闲时髦;直筒款巧妙糅合酷帅感与时髦感,塑造街头潮人印象。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#雪纺*风格#知性*风格#性感*图案#线条*裤长#连体裤*裤款式#木耳边", "response": "雪纺面料的一袭连体裤,舒适的手感,轻盈的穿着,宽松的版型,让上身穿着没有束缚感。一字肩的设计,木耳的花边,显露颈部柔美的线条,与性感的锁骨,展现女性知性的一面,木耳花边的<UNK>设计,显露穿着的甜美感,与少女味。高收腰的设计,拉伸腰部的曲线,提高腰线,显露穿着高挑的身姿。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#简约*图案#线条*裤款式#口袋*裤款式#拉链", "response": "侧缝处添置有立体拉链口袋作为装饰,实用性强且兼备美观性。净色的大体外观,简约低调,大方得体,易于搭配。裤腰处植入张弛有度的弹性带,贴合腰部,适合于大多数人穿着。衣身剪裁干净利落,线条流畅。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*图案#条纹*图案#线条*衣样式#衬衫", "response": "白色的衬衫采用了百褶的袖子设计,既修饰了手臂线条,又为整体增强了设计感。背带裤是永不过时的条纹款式,加上阔腿裤的设计,更显女性身材。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#简约*风格#休闲*裤长#短裤*裤款式#破洞", "response": "选用优质的纯棉面料打造出舒适的质感,而且上身不会扎身。同时,个性破洞细节设计,增加了牛仔短裤的细节感和吸睛度。此外,简约好搭的配色,柔和你的棱角,让你看起来温柔又平易近人。适合约会等休闲场景,是你衣柜里不可或缺的时髦单品之一。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#水洗*风格#潮*裤款式#不规则*裤口#毛边", "response": "年轻潮流的设计品味,洋气又好穿。细节相当丰富有看点,融入水洗磨白,使其充满时尚不羁的气息。裤脚前后毛边处理,配上不规则脚口,更添青春活力。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*颜色#蓝色*风格#简约*裤型#背带裤*裤款式#纽扣", "response": "背带裤的选用天蓝色的主题,远远看上去就像是蓝色<UNK>悬挂在活跃孩子的身上。简约的背带设计,可随时拆开的纽扣,让稚嫩孩子穿衣时不费吹灰之力。腰部更是搭配弹性材料缝制的腰带,不仅方便穿戴而且完美的起到了修饰作用。后背交叉背带,更是独特新颖的处理,更好更牢固的穿搭,不易滑落。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#涤纶*裤款式#拼接*裤款式#口袋", "response": "前置的口袋盖拼接,为本来单调的设计布局增加了亮点,更突出了裤子的个性化特点。加上精致的涤纶梭织面料制作,具备更加亲肤不刺激的丝质般触感,给你带来更加柔软舒适的穿着体验。其良好的透气性,有效提升了裤子的吸湿排汗性能,为你提供更加清爽舒适的体感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙下摆#荷叶边*裙领型#圆领*裙袖型#收口*裙款式#螺纹", "response": "此款上衣采用了经典的圆领款式打造,贴合舒适并能修饰出完美的脸型。同时螺纹的收口贴合身材更完美,在前身处采用了可爱的小狮子造型,带<UNK>真的感觉,而狮子的毛发更是立体精致,显得真实又有丰富的层次。裙身的下摆处采用了荷叶边的设计,俏皮活泼更可爱。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*材质#网纱*风格#青春*图案#印花*衣样式#衬衫*衣领型#v领*衣款式#拼接", "response": "这一款衬衫交叠v领的设计,修饰脖颈尽显女人味,宽松的廓形,穿上非常轻松有范毫不拘束,并很好的遮盖身材,非常显瘦。时尚的网纱拼接,自然美感特别出彩。精致印花,青春减龄特别活力。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#运动*风格#性感*衣样式#西装*衣领型#一字领*衣款式#荷叶边", "response": "荷叶边能够表达出女性的优雅,BRAND的这款上衣,将荷叶边很好地运动起来。性感的一字肩设计,荷叶边从一侧手臂的手肘从前胸绕到另一侧,有着前短后长的感觉,自然垂坠很有层次感,举手投足之间,灵动而优雅。西装袖很好地融合,优雅之中透着小帅气。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#运动*风格#休闲*风格#青春*图案#字母*图案#形状*图案#文字*图案#刺绣*图案#撞色*衣样式#卫衣*衣袖型#收口*衣款式#抽绳*衣款式#连帽", "response": "这款dolce&gabbana的连帽卫衣,撞色的<UNK>字母加上桃心形状的刺绣图案令人耳目一新,举手投足间散发阳光活力少女的青春气息;连帽款式尽显帅气利落风范,细节上采用抽绳处理实用又美观,洋溢满满的运动休闲范儿;加之袖口处的收口设计别出心裁,宽松的衣身烘托出慵懒率性的格调。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#休闲*图案#字母*图案#文字*图案#线条*图案#印花*图案#撞色*裤款式#拼接*裤口#小脚", "response": "上下<UNK>拼接撞色设计,吸睛十足,轻松聚焦视线,个性前卫。字母印花设计,巧添时尚细节看点,以鲜明撞色渲染,展现年轻活力气息。长袖套头轮廓,线条处理恰到好处,呼应休闲基调。宽松的版型,不挑身材,上身好看。连帽的设计美观实用,防风保暖。时尚百搭,可以搭配牛仔裤、紧身裤、休闲裤、束脚裤等。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#灰色*风格#复古*风格#文艺*风格#青春*图案#卡通*图案#复古*衣样式#风衣*衣长#中长款", "response": "一款好看的风衣大概能为这个姹紫嫣红的春天多一份色彩,沉静的灰色上身具有非常好的效果,显得热更加内敛沉稳,有一股淡淡的复古文艺风格。而中长的版型自然下垂,修身显高又瞬间提升气场。后背的卡通图案别致可爱,更添青春气息。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#性感*衣样式#针织衫*衣款式#露肩", "response": "这一款针织衫露肩设计,风情浪漫性感迷人。略微宽松的版型舒适随意,很好的掩饰身材小小的缺陷,看起来精致高挑。加上时尚的花边下摆,错落有致视觉美丽。精致袖口,修饰手臂特别出彩。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#松紧带*裙款式#飘带", "response": "<UNK>冷风的气质感,干净利落的feel,小露香肩有一种含蓄撩拨的趣味,袖口领口的飘带设计很是巧妙,让整个小衫更加优美,领子部分的两边肩部松紧带设计,大胆随意的穿出多种效果。让仙女们走在时尚<UNK>的道路上更加自信。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*衣样式#卫衣*衣领型#圆领", "response": "针织卫衣采用了简洁的圆领设计,非常百搭,免去了你<UNK>找不到搭配的烦恼。合体的剪裁设计,让你在跑步健身时轻巧灵便,活动自如,达到更好的锻炼效果。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#白色*风格#简约*图案#蝴蝶结*图案#刺绣*衣样式#衬衫*衣袖型#喇叭袖", "response": "这一款很好穿的白色衬衫,利落的宽松版型几乎是不挑身材的,无门襟的设计也符合整体的气息。胸前做了绣花的点缀,为简约的衬衫增添了几分柔美的气质。七分的喇叭袖露出小臂,蝴蝶结的点缀显得气质更加的浪漫。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*颜色#灰色*风格#英伦*风格#复古*图案#格子*图案#复古*裙型#百褶*裙长#半身裙*裙款式#波浪*裙款式#收腰", "response": "BRAND这款半身裙,用复古的灰色格纹,打造出十足英伦范儿。搭配百褶裙身,为整体增添层次感,穿出减龄风。同时,波浪边的收腰设计,不仅更好的修饰腰部曲线,还为整体气质增添了优雅美感。而雪纺面料,使你在夏日也能穿出清爽感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#复古*风格#文艺*风格#中国风*风格#性感*图案#复古*图案#刺绣*裙型#a字*裙领型#v领", "response": "超级具有中国风气息的一款裙子,带着古典的柔婉。花朵刺绣的运用,色彩缤纷靓丽,冲击视觉,演绎复古文艺范儿。经典的气质v领,既凸显了小性感与时尚,又起到点睛的效果。腰部系的设计,配上a字版型,显瘦又遮肚子。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#性感*图案#线条*衣样式#马甲*衣领型#翻领*衣款式#露背*衣款式#绑带*衣款式#吊带*衣款式#收腰", "response": "小吊带马甲叠穿造型,年轻而不失时尚格调,有着绑带收腰设计,强调出纤细的腰肢,摩登帅气;小翻领露出纤细修长的脖颈线条,散发清爽利落的小清新气息;性感交叉露背设计,别致吸睛,女人味十足;高腰伞形裙摆自然撑开,上身塑造黄金比例,突显得腰更细,巧妙地修饰身型。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#黑白*风格#英伦*风格#简约*图案#格子*图案#线条*衣样式#外套*衣样式#西装*衣门襟#一粒扣", "response": "这款西装外套,版型加长修身,能更好凸显成熟与稳重。细细密密的黑白图案,远远看形成自然的格纹,时髦英伦范儿。平整肩线将线条感拉伸,让身姿显得更挺拔有型。一粒扣设计,简约大气。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#潮*图案#线条*图案#撞色*裙领型#圆领", "response": "采用经典的圆领设计,修饰颈部线条的同时,且上身穿着舒适不易变形,轻松演绎时髦造型。大面积撞色贴花装饰,无疑是点睛之笔,为简洁的款式轮廓带来了更多的视觉层次感与潮流气息。与众不同的你,不在畏惧撞衫的尴尬。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#雪纺*颜色#纯色*风格#清新*图案#纯色*图案#碎花*衣样式#衬衫*衣款式#荷叶边", "response": "这件荷叶边雪纺碎花衬衫和其他的碎花衬衫相比整体的风格会更优雅柔美一些。颜色上也是比较<UNK>清新的花型配色和纯色的大身相结合,会让人看着很舒服,而且每个碎花之间都会<UNK>限视觉上不会觉得太紧密,更有法式的浪漫优雅。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*风格#简约*风格#青春*风格#潮*风格#性感*图案#线条*图案#蕾丝*衣样式#雪纺衫*衣领型#圆领*衣款式#勾花镂空", "response": "这款时尚镂空雪纺衫,带有性感蕾丝工艺,精致百搭的圆领设计,彰显显独特的质感。背部线条流畅,笔挺而有型干练,彰显潮流时尚之风。走线十分笔直,针脚均匀,尽显裁缝之细致。简约时尚的透视蕾丝袖口,彰显成熟又不乏活力的青春气质。青春优雅的独特风格,流露出满满的潮流感。"}
{"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#休闲*裤长#短裤", "response": "来自英国<UNK>的这款儿童休闲短裤,以趣味的小恐龙图案满印裤身,可爱童真,彰显出宝宝的活泼天真范儿。柔软的全棉布料质地,手感细腻顺滑,亲和宝宝的肌肤,带来舒适自在的穿着体验。"}

================================================
FILE: ptuning/README.md
================================================
# ChatGLM-6B-PT
本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。

*Read this in [English](README_en.md).

## 软件依赖
运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖
```
pip install rouge_chinese nltk jieba datasets
```
## 使用方法

### 下载数据集
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

```json
{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
```

从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。

### 训练

#### P-Tuning v2

运行以下指令进行训练:
```shell
bash train.sh
```
`train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

如果你想要[从本地加载模型](../README_en.md#load-the-model-locally),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。

#### Finetune

如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令:

```shell
bash ds_train_finetune.sh
```

### 推理

在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定 `evaluate.sh` 中的参数:

```shell
--model_name_or_path THUDM/chatglm-6b
--ptuning_checkpoint $CHECKPOINT_PATH
```

仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`:

```shell
--model_name_or_path $CHECKPOINT_PATH
```

评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。

### 例子
#### 示例1
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
* Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
* Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。

#### 示例2

* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
* Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
* Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。

### 评估结果

|               | Finetune | P-tuning v2 | LoRA |
| ------------- | ----------- | ----- | ------------- |
| BLEU-4        | 8.01    | 8.10 | 7.62 |
| Rouge-1       | 31.23  | 31.12 | 30.60 |
| Rouge-2       | 7.36    | 7.11 | 6.96 |
| Rouge-l       | 25.08  | 24.97 | 24.80 |
| Training Loss | 3.00 | 3.74 | 3.32 |



#### 实验设置

```
max_source_length=64
max_target_length=64
max_steps=3000
```

##### P-tuning v2

```
pre_seq_len=128
learning_rate=2e-2
quantization_bit=4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

##### Finetune

```
learning_rate=1e-4
fp16
num_gpus=4
per_device_train_batch_size=4
gradient_accumulation_steps=1
```

##### LoRA

实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)

```
learning_rate=5e-4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

## 模型部署
首先载入Tokenizer:

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```

1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):

```python
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
```
注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。

2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint:

```python
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
```

之后根据需求可以进行量化,也可以直接使用:

```python
# Comment out the following line if you don't use quantization
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

response, history = model.chat(tokenizer, "你好", history=[])
```

**[23/04/19]** 你也可以直接运行支持加载 P-Tuning v2 checkpoint 的 [web demo](./web_demo.py)
```shell
bash web_demo.sh
```
可能需要修改 [web_demo.sh](./web_demo.sh) 的内容以符合你实际的 checkpoint 情况。

## 使用自己的数据集
修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。

## 对话数据集

如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如以下是一个三轮对话的训练数据:

```json lines
{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
```

训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接。要注意超过输入长度 `max_source_length` 的内容会被截断。

可以参考以下指令:

```shell
bash train_chat.sh
```

## 引用

```
@inproceedings{liu2022p,
  title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
  author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
  pages={61--68},
  year={2022}
}
```





================================================
FILE: ptuning/README_en.md
================================================
# ChatGLM-6B-PT
This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run.

The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.

## Software dependencies
Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required
```
pip install rouge_chinese nltk jieba datasets
```
## Instructions

### Download the dataset
The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content).

```json
{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
```

From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory.

### Training

#### P-Tuning v2

Run the following commands for training:
```shell
bash train.sh
```
`PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision.

Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation.

If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path.

#### Finetune
To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command:

```shell
bash ds_train_finetune.sh
```

### Inference

During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`:

```shell
--model_name_or_path THUDM/chatglm-6b
--ptuning_checkpoint $CHECKPOINT_PATH
```

It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before:

```shell
--model_name_or_path $CHECKPOINT_PATH
```

The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.

### Example
#### Example 1
* Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
* Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
* Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
* Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。

#### Example 2

* Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
* Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
* Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
* Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。

### evaluation result

|               | Finetune | P-tuning v2 | LoRA |
| ------------- | ----------- | ----- | ------------- |
| BLEU-4        | 8.01    | 8.10 | 7.62 |
| Rouge-1       | 31.23  | 31.12 | 30.60 |
| Rouge-2       | 7.36    | 7.11 | 6.96 |
| Rouge-l       | 25.08  | 24.97 | 24.80 |
| Training Loss | 3.00 | 3.74 | 3.32 |

#### Experiment Settings

```
max_source_length=64
max_target_length=64
max_steps=3000
```

##### P-tuning v2

```
pre_seq_len=128
learning_rate=2e-2
quantization_bit=4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

##### Finetune

```
learning_rate=1e-4
fp16
num_gpus=4
per_device_train_batch_size=4
gradient_accumulation_steps=1
```

##### LoRA

The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)

```
learning_rate=5e-4
per_device_train_batch_size=16
gradient_accumulation_steps=1
```

## Model Deployment
First load the tokenizer:

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
```

1. If a new Checkpoint needs to be loaded (only contains the PrefixEncoder parameter):

```python
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
     if k.startswith("transformer.prefix_encoder."):
         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
```
Note that you may need to change `pre_seq_len` to the actual value of your training. If you [load model from local](../README_en.md#load-the-model-locally), you need to change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).

2. If you need to load the old checkpoint (including both ChatGLM-6B and PrefixEncoder parameters), or perform full parameter fine-tuning, then directly load the entire checkpoint:

```python
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
```

Then it can be quantified according to the needs, or it can be used directly:

```python
# Comment out the following line if you don't use quantization
model = model. quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

response, history = model.chat(tokenizer, "Hello", history=[])
```

**[23/04/19]** You can also directly run [web demo](./web_demo.py) which supports loading P-Tuning v2 checkpoint
```shell
bash web_demo.sh
```
It may be necessary to modify the content of [web_demo.sh](./web_demo.sh) to match your actual checkpoint situation.

## Use your own dataset
Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text.
You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset.

## Dialog Dataset

If you need to use multiple rounds of dialogue data to train the model, you can provide chat history. For example, the following is the training data for a three-round dialogue:

```json lines
{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
{"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
```

During training, you need to specify `--history_column` as the key of the chat history in the data (`history` in this example), and the chat history will be stitched automatically. Note that content exceeding the input length `max_source_length` will be truncated.

You can refer to the following instructions:

```shell
bash train_chat.sh
```

## Citation

```
@inproceedings{liu2022p,
   title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
   author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
   booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
   pages={61--68},
   year={2022}
}
```

================================================
FILE: ptuning/arguments.py
================================================
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    ptuning_checkpoint: str = field(
        default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            "help": (
                "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
                "the model's position embeddings."
            )
        },
    )
    quantization_bit: Optional[int] = field(
        default=None
    )
    pre_seq_len: Optional[int] = field(
        default=None
    )
    prefix_projection: bool = field(
        default=False
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    prompt_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    response_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    history_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the history of chat."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
            )
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )

    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the decoder_start_token_id."
                "Useful for multilingual models like mBART where the first generated token"
                "needs to be the target language token (Usually it is the target language token)"
            )
        },
    )

    

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
            raise ValueError("Need either a dataset name or a training/validation/test file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length



================================================
FILE: ptuning/deepspeed.json
================================================
{
  "train_micro_batch_size_per_gpu": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients" : true
  }
}

================================================
FILE: ptuning/ds_train_finetune.sh
================================================

LR=1e-4

MASTER_PORT=$(shuf -n 1 -i 10000-65535)

deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
    --deepspeed deepspeed.json \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --test_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir ./output/adgen-chatglm-6b-ft-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --predict_with_generate \
    --max_steps 5000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --fp16



================================================
FILE: ptuning/evaluate.sh
================================================
PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path THUDM/chatglm-6b \
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4


================================================
FILE: ptuning/evaluate_finetune.sh
================================================
CHECKPOINT=adgen-chatglm-6b-ft-1e-4
STEP=3000

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 256 \
    --max_target_length 256 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --fp16_full_eval


================================================
FILE: ptuning/main.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

import logging
import os
import sys
import json

import numpy as np
from datasets import load_dataset
import jieba 
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer

from arguments import ModelArguments, DataTrainingArguments

logger = logging.getLogger(__name__)

def main():

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load dataset
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
        extension = data_args.validation_file.split(".")[-1]
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
        extension = data_args.test_file.split(".")[-1]

    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    if model_args.ptuning_checkpoint is not None:
        # Evaluation
        # Loading extra state dict of prefix encoder
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    prompt_column = data_args.prompt_column
    response_column = data_args.response_column
    history_column = data_args.history_column
    
    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    def preprocess_function_eval(examples):
        inputs, targets = [], []
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                if history_column is None or len(examples[history_column][i]) == 0:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    def preprocess_function_train(examples):
        max_seq_length = data_args.max_source_length + data_args.max_target_length

        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]

                if history_column is None:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

                prompt = prefix + prompt
                a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
                b_ids = tokenizer.encode(text=answer, add_special_tokens=False)

                if len(a_ids) > data_args.max_source_length - 1:
                    a_ids = a_ids[: data_args.max_source_length - 1]

                if len(b_ids) > data_args.max_target_length - 2:
                    b_ids = b_ids[: data_args.max_target_length - 2]

                input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

                context_length = input_ids.index(tokenizer.bos_token_id)
                mask_position = context_length - 1
                labels = [-100] * context_length + input_ids[mask_position+1:]
                
                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)

        return model_inputs
    
    def print_dataset_example(example):
        print("input_ids",example["input_ids"])
        print("inputs", tokenizer.decode(example["input_ids"]))
        print("label_ids", example["labels"])
        print("labels", tokenizer.decode(example["labels"]))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function_train,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        print_dataset_example(train_dataset[0])

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        print_dataset_example(eval_dataset[0])

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        print_dataset_example(predict_dataset[0])

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False
    )

    # Metric
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
            result = scores[0]
            
            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        save_prefixencoder=model_args.pre_seq_len is not None
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        # elif last_checkpoint is not None:
        #     checkpoint = last_checkpoint
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        # trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                labels = tokenizer.batch_decode(
                    predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                labels = [label.strip() for label in labels]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    for p, l in zip(predictions, labels):
                        res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
                        writer.write(f"{res}\n")
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()


================================================
FILE: ptuning/train.sh
================================================
PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4



================================================
FILE: ptuning/train_chat.sh
================================================
PRE_SEQ_LEN=128
LR=1e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file $CHAT_TRAIN_DATA \
    --validation_file $CHAT_VAL_DATA \
    --prompt_column prompt \
    --response_column response \
    --history_column history \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir $CHECKPOINT_NAME \
    --overwrite_output_dir \
    --max_source_length 256 \
    --max_target_length 256 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4



================================================
FILE: ptuning/trainer.py
================================================
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from distutils.util import strtobool
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import (
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

# isort: on

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import Repository, create_repo
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.optimization import Adafactor, get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
    CONFIG_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    can_return_loss,
    find_labels,
    get_full_repo_name,
    is_accelerate_available,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_compile_available,
    is_torch_neuroncore_available,
    is_torch_tpu_available,
    logging,
)
from transformers.utils.generic import ContextManagers


_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False


skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)

    """

    from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        save_prefixencoder: bool = False,
    ):
        self.save_prefixencoder = save_prefixencoder
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        # At this stage the model is already loaded
        if getattr(model, "is_loaded_in_8bit", False):
            if getattr(model, "_is_int8_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if not args.fsdp_config["xla"] and args.local_rank == -1:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD

            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

            self.forword_prefetch = False
            if self.args.fsdp_config.get("forword_prefect", False):
                self.forword_prefetch = True

            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or args.deepspeed
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )

        if args.fp16 or args.bf16:
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
                else:
                    args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                #  bf16 does not need grad scaling
                self.do_grad_scaling = self.amp_dtype == torch.float16
                if self.do_grad_scaling:
                    if self.sharded_ddp is not None:
                        self.scaler = ShardedGradScaler()
                    elif self.fsdp is not None:
                        from torch.distributed.fsdp.sharded_grad_scaler import (
                            ShardedGradScaler as FSDPShardedGradScaler,
                        )

                        self.scaler = FSDPShardedGradScaler()
                    elif is_torch_tpu_available():
                        from torch_xla.amp import GradScaler

                        self.scaler = GradScaler()
                    else:
                        self.scaler = torch.cuda.amp.GradScaler()
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformer.TrainerCallback`].

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformer.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformer.TrainerCallback`].

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                " you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                return RandomSampler(self.train_dataset, generator=generator)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            print(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    print(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        """

        # parse args.optim_args
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from transformers.optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
      
Download .txt
gitextract_3917dvmb/

├── .github/
│   └── ISSUE_TEMPLATE/
│       ├── bug_report.yaml
│       ├── config.yml
│       └── feature_request.yml
├── .gitignore
├── FAQ.md
├── LICENSE
├── MODEL_LICENSE
├── PROJECT.md
├── README.md
├── README_en.md
├── UPDATE.md
├── api.py
├── cli_demo.py
├── cli_demo_vision.py
├── improve/
│   ├── README.md
│   └── data_sample.jsonl
├── ptuning/
│   ├── README.md
│   ├── README_en.md
│   ├── arguments.py
│   ├── deepspeed.json
│   ├── ds_train_finetune.sh
│   ├── evaluate.sh
│   ├── evaluate_finetune.sh
│   ├── main.py
│   ├── train.sh
│   ├── train_chat.sh
│   ├── trainer.py
│   ├── trainer_seq2seq.py
│   ├── web_demo.py
│   └── web_demo.sh
├── requirements.txt
├── resources/
│   └── WECHAT.md
├── utils.py
├── web_demo.py
├── web_demo2.py
├── web_demo_old.py
└── web_demo_vision.py
Download .txt
SYMBOL INDEX (106 symbols across 13 files)

FILE: api.py
  function torch_gc (line 11) | def torch_gc():
  function create_item (line 22) | async def create_item(request: Request):

FILE: cli_demo.py
  function build_prompt (line 16) | def build_prompt(history):
  function signal_handler (line 24) | def signal_handler(signal, frame):
  function main (line 29) | def main():

FILE: cli_demo_vision.py
  function build_prompt (line 18) | def build_prompt(history, prefix):
  function signal_handler (line 26) | def signal_handler(signal, frame):
  function main (line 31) | def main():

FILE: ptuning/arguments.py
  class ModelArguments (line 6) | class ModelArguments:
  class DataTrainingArguments (line 65) | class DataTrainingArguments:
    method __post_init__ (line 212) | def __post_init__(self):

FILE: ptuning/main.py
  function main (line 49) | def main():
  function _mp_fn (line 424) | def _mp_fn(index):

FILE: ptuning/trainer.py
  class Trainer (line 222) | class Trainer:
    method __init__ (line 307) | def __init__(
    method add_callback (line 679) | def add_callback(self, callback):
    method pop_callback (line 690) | def pop_callback(self, callback):
    method remove_callback (line 706) | def remove_callback(self, callback):
    method _move_model_to_device (line 717) | def _move_model_to_device(self, model, device):
    method _set_signature_columns_if_needed (line 723) | def _set_signature_columns_if_needed(self):
    method _remove_unused_columns (line 731) | def _remove_unused_columns(self, dataset: "datasets.Dataset", descript...
    method _get_collator_with_removed_columns (line 757) | def _get_collator_with_removed_columns(
    method _get_train_sampler (line 775) | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    method get_train_dataloader (line 846) | def get_train_dataloader(self) -> DataLoader:
    method _get_eval_sampler (line 896) | def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.u...
    method get_eval_dataloader (line 925) | def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) ...
    method get_test_dataloader (line 975) | def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    method create_optimizer_and_scheduler (line 1023) | def create_optimizer_and_scheduler(self, num_training_steps: int):
    method create_optimizer (line 1039) | def create_optimizer(self):
    method get_optimizer_cls_and_kwargs (line 1096) | def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any...
    method create_scheduler (line 1186) | def create_scheduler(self, num_training_steps: int, optimizer: torch.o...
    method num_examples (line 1203) | def num_examples(self, dataloader: DataLoader) -> int:
    method _hp_search_setup (line 1217) | def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
    method _report_to_hp_search (line 1258) | def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, ...
    method _tune_save_checkpoint (line 1276) | def _tune_save_checkpoint(self):
    method call_model_init (line 1289) | def call_model_init(self, trial=None):
    method torch_jit_model_eval (line 1303) | def torch_jit_model_eval(self, model, dataloader, training=False):
    method ipex_optimize_model (line 1341) | def ipex_optimize_model(self, model, training=False, dtype=torch.float...
    method _wrap_model (line 1364) | def _wrap_model(self, model, training=True, dataloader=None):
    method train (line 1556) | def train(
    method _inner_training_loop (line 1642) | def _inner_training_loop(
    method _get_output_dir (line 2055) | def _get_output_dir(self, trial):
    method _load_from_checkpoint (line 2075) | def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
    method _load_best_model (line 2133) | def _load_best_model(self):
    method _issue_warnings_after_load (line 2193) | def _issue_warnings_after_load(self, load_result):
    method _maybe_log_save_evaluate (line 2206) | def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignor...
    method _load_rng_state (line 2245) | def _load_rng_state(self, checkpoint):
    method _save_checkpoint (line 2286) | def _save_checkpoint(self, model, trial, metrics=None):
    method _load_optimizer_and_scheduler (line 2392) | def _load_optimizer_and_scheduler(self, checkpoint):
    method hyperparameter_search (line 2449) | def hyperparameter_search(
    method log (line 2545) | def log(self, logs: Dict[str, float]) -> None:
    method _prepare_input (line 2562) | def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torc...
    method _prepare_inputs (line 2580) | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]])...
    method compute_loss_context_manager (line 2596) | def compute_loss_context_manager(self):
    method autocast_smart_context_manager (line 2602) | def autocast_smart_context_manager(self, cache_enabled: Optional[bool]...
    method training_step (line 2621) | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torc...
    method compute_loss (line 2669) | def compute_loss(self, model, inputs, return_outputs=False):
    method is_local_process_zero (line 2701) | def is_local_process_zero(self) -> bool:
    method is_world_process_zero (line 2708) | def is_world_process_zero(self) -> bool:
    method save_model (line 2720) | def save_model(self, output_dir: Optional[str] = None, _internal_call:...
    method _save_tpu (line 2784) | def _save_tpu(self, output_dir: Optional[str] = None):
    method _save (line 2812) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
    method store_flos (line 2847) | def store_flos(self):
    method _sorted_checkpoints (line 2858) | def _sorted_checkpoints(
    method _rotate_checkpoints (line 2882) | def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
    method evaluate (line 2907) | def evaluate(
    method predict (line 2978) | def predict(
    method evaluation_loop (line 3040) | def evaluation_loop(
    method _nested_gather (line 3251) | def _nested_gather(self, tensors, name=None):
    method _pad_across_processes (line 3269) | def _pad_across_processes(self, tensor, pad_index=-100):
    method prediction_step (line 3303) | def prediction_step(
    method floating_point_ops (line 3408) | def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any...
    method init_git_repo (line 3426) | def init_git_repo(self, at_init: bool = False):
    method create_model_card (line 3473) | def create_model_card(
    method _push_from_checkpoint (line 3529) | def _push_from_checkpoint(self, checkpoint_folder):
    method push_to_hub (line 3571) | def push_to_hub(self, commit_message: Optional[str] = "End of training...
    method prediction_loop (line 3631) | def prediction_loop(
    method _gather_and_numpify (line 3777) | def _gather_and_numpify(self, tensors, name):
    method _add_sm_patterns_to_gitignore (line 3793) | def _add_sm_patterns_to_gitignore(self) -> None:

FILE: ptuning/trainer_seq2seq.py
  class Seq2SeqTrainer (line 30) | class Seq2SeqTrainer(Trainer):
    method evaluate (line 31) | def evaluate(
    method predict (line 80) | def predict(
    method prediction_step (line 138) | def prediction_step(
    method _pad_tensors_to_max_len (line 231) | def _pad_tensors_to_max_len(self, tensor, max_length):

FILE: ptuning/web_demo.py
  function postprocess (line 28) | def postprocess(self, y):
  function parse_text (line 42) | def parse_text(text):
  function predict (line 75) | def predict(input, chatbot, max_length, top_p, temperature, history):
  function reset_user_input (line 84) | def reset_user_input():
  function reset_state (line 88) | def reset_state():
  function main (line 119) | def main():

FILE: utils.py
  function auto_configure_device_map (line 8) | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
  function load_model_on_gpus (line 38) | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpu...

FILE: web_demo.py
  function postprocess (line 12) | def postprocess(self, y):
  function parse_text (line 26) | def parse_text(text):
  function predict (line 59) | def predict(input, chatbot, max_length, top_p, temperature, history):
  function reset_user_input (line 68) | def reset_user_input():
  function reset_state (line 72) | def reset_state():

FILE: web_demo2.py
  function get_model (line 13) | def get_model():
  function predict (line 24) | def predict(input, max_length, top_p, temperature, history=None):

FILE: web_demo_old.py
  function predict (line 12) | def predict(input, max_length, top_p, temperature, history=None):

FILE: web_demo_vision.py
  function postprocess (line 12) | def postprocess(self, y):
  function parse_text (line 26) | def parse_text(text):
  function predict (line 59) | def predict(input, image_path, chatbot, max_length, top_p, temperature, ...
  function predict_new_image (line 70) | def predict_new_image(image_path, chatbot, max_length, top_p, temperature):
  function reset_user_input (line 81) | def reset_user_input():
  function reset_state (line 85) | def reset_state():
Condensed preview — 37 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (374K chars).
[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yaml",
    "chars": 1789,
    "preview": "name: 🐞 Bug/Help\ndescription: File a bug/issue\ntitle: \"[BUG/Help] <title>\"\nlabels: []\nbody:\n- type: checkboxes\n  attribu"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "chars": 27,
    "preview": "blank_issues_enabled: false"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "chars": 763,
    "preview": "name: Feature request\ndescription: Suggest an idea for this project\ntitle: \"[Feature] <title>\"\nlabels: []\nbody:\n- type: "
  },
  {
    "path": ".gitignore",
    "chars": 1972,
    "preview": "# Byte-compiled / optimized / DLL files\r\n__pycache__/\r\n*.py[cod]\r\n*$py.class\r\n\r\n# C extensions\r\n*.so\r\n\r\n# Distribution /"
  },
  {
    "path": "FAQ.md",
    "chars": 668,
    "preview": "## Q1\n\n**Mac直接加载量化后的模型出现提示 `clang: error: unsupported option '-fopenmp'**\n\n这是由于Mac由于本身缺乏omp导致的,此时可运行但是单核。需要单独安装 openmp 依"
  },
  {
    "path": "LICENSE",
    "chars": 11336,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "MODEL_LICENSE",
    "chars": 3063,
    "preview": "The ChatGLM-6B License\n\n1. 定义\n\n“许可方”是指分发其软件的 ChatGLM-6B 模型团队。\n\n“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。(不包括二代模型 ChatGLM2-6B 以及后续模"
  },
  {
    "path": "PROJECT.md",
    "chars": 3116,
    "preview": "# 友情链接\n\n对 ChatGLM 进行加速或者重新实现的开源项目:\n* [lyraChatGLM](https://huggingface.co/TMElyralab/lyraChatGLM): 对 ChatGLM-6B 进行推理加速,最"
  },
  {
    "path": "README.md",
    "chars": 16990,
    "preview": "# ChatGLM-6B\n\n<p align=\"center\">\n   🌐 <a href=\"https://chatglm.cn/blog\" target=\"_blank\">Blog</a> • 🤗 <a href=\"https://hu"
  },
  {
    "path": "README_en.md",
    "chars": 22986,
    "preview": "# ChatGLM-6B\n\n<p align=\"center\">\n   🌐 <a href=\"https://chatglm.cn/blog\" target=\"_blank\">Blog</a> • 🤗 <a href=\"https://hu"
  },
  {
    "path": "UPDATE.md",
    "chars": 5252,
    "preview": "## 更新信息\n**[2023/05/17]** 发布 [VisualGLM-6B](https://github.com/THUDM/VisualGLM-6B),一个支持图像理解的多模态对话语言模型。\n\n![](resources/vis"
  },
  {
    "path": "api.py",
    "chars": 1865,
    "preview": "from fastapi import FastAPI, Request\nfrom transformers import AutoTokenizer, AutoModel\nimport uvicorn, json, datetime\nim"
  },
  {
    "path": "cli_demo.py",
    "chars": 1710,
    "preview": "import os\nimport platform\nimport signal\nfrom transformers import AutoTokenizer, AutoModel\nimport readline\n\ntokenizer = A"
  },
  {
    "path": "cli_demo_vision.py",
    "chars": 1874,
    "preview": "import os\nimport platform\nimport signal\nimport sys\n\nfrom transformers import AutoTokenizer, AutoModel\nimport readline\n\nt"
  },
  {
    "path": "improve/README.md",
    "chars": 1634,
    "preview": "# ChatGLM-6B Badcase 反馈计划\nChatGLM-6B 自3月14号发布以来受到了广大开发者和用户的喜爱,截至4月22号 GitHub 的 star 数达到 2 万,各个渠道模型的累计下载量过 100 万,并连续 12 天"
  },
  {
    "path": "improve/data_sample.jsonl",
    "chars": 20785,
    "preview": "{\"prompt\": \"请根据以下标签为商品编写一段广告\\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤\", \"response\": \"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效"
  },
  {
    "path": "ptuning/README.md",
    "chars": 6850,
    "preview": "# ChatGLM-6B-PT\n本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少"
  },
  {
    "path": "ptuning/README_en.md",
    "chars": 9341,
    "preview": "# ChatGLM-6B-PT\nThis repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUD"
  },
  {
    "path": "ptuning/arguments.py",
    "chars": 8478,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Optional\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Argu"
  },
  {
    "path": "ptuning/deepspeed.json",
    "chars": 509,
    "preview": "{\r\n  \"train_micro_batch_size_per_gpu\": \"auto\",\r\n  \"zero_allow_untested_optimizer\": true,\r\n  \"fp16\": {\r\n    \"enabled\": \"a"
  },
  {
    "path": "ptuning/ds_train_finetune.sh",
    "chars": 766,
    "preview": "\nLR=1e-4\n\nMASTER_PORT=$(shuf -n 1 -i 10000-65535)\n\ndeepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \\\n    --dee"
  },
  {
    "path": "ptuning/evaluate.sh",
    "chars": 660,
    "preview": "PRE_SEQ_LEN=128\nCHECKPOINT=adgen-chatglm-6b-pt-128-2e-2\nSTEP=3000\n\nCUDA_VISIBLE_DEVICES=0 python3 main.py \\\n    --do_pre"
  },
  {
    "path": "ptuning/evaluate_finetune.sh",
    "chars": 562,
    "preview": "CHECKPOINT=adgen-chatglm-6b-ft-1e-4\nSTEP=3000\n\nCUDA_VISIBLE_DEVICES=0 python3 main.py \\\n    --do_predict \\\n    --validat"
  },
  {
    "path": "ptuning/main.py",
    "chars": 18578,
    "preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the "
  },
  {
    "path": "ptuning/train.sh",
    "chars": 753,
    "preview": "PRE_SEQ_LEN=128\nLR=2e-2\n\nCUDA_VISIBLE_DEVICES=0 python3 main.py \\\n    --do_train \\\n    --train_file AdvertiseGen/train.j"
  },
  {
    "path": "ptuning/train_chat.sh",
    "chars": 745,
    "preview": "PRE_SEQ_LEN=128\nLR=1e-2\n\nCUDA_VISIBLE_DEVICES=0 python3 main.py \\\n    --do_train \\\n    --train_file $CHAT_TRAIN_DATA \\\n "
  },
  {
    "path": "ptuning/trainer.py",
    "chars": 185586,
    "preview": "# coding=utf-8\n# Copyright 2020-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (t"
  },
  {
    "path": "ptuning/trainer_seq2seq.py",
    "chars": 11496,
    "preview": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "ptuning/web_demo.py",
    "chars": 5684,
    "preview": "import os, sys\n\nimport gradio as gr\nimport mdtex2html\n\nimport torch\nimport transformers\nfrom transformers import (\n    A"
  },
  {
    "path": "ptuning/web_demo.sh",
    "chars": 217,
    "preview": "PRE_SEQ_LEN=128\n\nCUDA_VISIBLE_DEVICES=0 python3 web_demo.py \\\n    --model_name_or_path THUDM/chatglm-6b \\\n    --ptuning_"
  },
  {
    "path": "requirements.txt",
    "chars": 96,
    "preview": "protobuf\ntransformers==4.27.1\ncpm_kernels\ntorch>=1.10\ngradio\nmdtex2html\nsentencepiece\naccelerate"
  },
  {
    "path": "resources/WECHAT.md",
    "chars": 193,
    "preview": "<div align=\"center\">\n<img src=wechat.jpg width=\"60%\"/>\n\n<p> 扫码关注公众号,加入「ChatGLM交流群」 </p>\n<p> Scan the QR code to follow t"
  },
  {
    "path": "utils.py",
    "chars": 1874,
    "preview": "import os\nfrom typing import Dict, Tuple, Union, Optional\n\nfrom torch.nn import Module\nfrom transformers import AutoMode"
  },
  {
    "path": "web_demo.py",
    "chars": 3637,
    "preview": "from transformers import AutoModel, AutoTokenizer\nimport gradio as gr\nimport mdtex2html\n\ntokenizer = AutoTokenizer.from_"
  },
  {
    "path": "web_demo2.py",
    "chars": 2127,
    "preview": "from transformers import AutoModel, AutoTokenizer\nimport streamlit as st\nfrom streamlit_chat import message\n\n\nst.set_pag"
  },
  {
    "path": "web_demo_old.py",
    "chars": 1990,
    "preview": "from transformers import AutoModel, AutoTokenizer\nimport gradio as gr\n\ntokenizer = AutoTokenizer.from_pretrained(\"THUDM/"
  },
  {
    "path": "web_demo_vision.py",
    "chars": 4587,
    "preview": "from transformers import AutoModel, AutoTokenizer\nimport gradio as gr\nimport mdtex2html\n\ntokenizer = AutoTokenizer.from_"
  }
]

About this extraction

This page contains the full source code of the zai-org/ChatGLM-6B GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 37 files (352.1 KB), approximately 100.3k tokens, and a symbol index with 106 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!