Full Code of sunzeyeah/RLHF for AI

master cd1a6d54971e cached
135 files
1.4 MB
350.5k tokens
1225 symbols
1 requests
Download .txt
Showing preview only (1,459K chars total). Download the full file or copy to clipboard to get everything.
Repository: sunzeyeah/RLHF
Branch: master
Commit: cd1a6d54971e
Files: 135
Total size: 1.4 MB

Directory structure:
gitextract_138ju87y/

├── README.md
├── docs/
│   └── INSTALL_ON_WINDOWS.md
├── examples/
│   ├── data_prepare.sh
│   ├── eval_pretrain.sh
│   ├── eval_reward.sh
│   ├── pred_pretrain.sh
│   ├── pred_reward.sh
│   ├── pretrain.sh
│   ├── train_dpo.sh
│   ├── train_reward.sh
│   ├── train_rlhf.sh
│   └── train_sft.sh
├── notebook/
│   ├── data_prepare.ipynb
│   ├── reward.ipynb
│   ├── rlhf.ipynb
│   ├── sft.ipynb
│   └── test.ipynb
├── requirements.txt
└── src/
    ├── __init__.py
    ├── data/
    │   ├── __init__.py
    │   ├── data.py
    │   ├── data_types.py
    │   └── pipeline.py
    ├── data_prepare.py
    ├── eval_pretrain.py
    ├── models/
    │   ├── __init__.py
    │   ├── loss.py
    │   ├── ppo.py
    │   ├── reward.py
    │   ├── rlhf_engine.py
    │   └── trainer.py
    ├── pretrain.py
    ├── pretrain_wo_trainer.py
    ├── resources/
    │   ├── config/
    │   │   ├── deepspeed/
    │   │   │   ├── Experiment.md
    │   │   │   ├── README.md
    │   │   │   ├── stage-1.json
    │   │   │   ├── stage-2.json
    │   │   │   ├── stage-3-no_trainer.json
    │   │   │   └── stage-3.json
    │   │   ├── log.conf
    │   │   └── ppo_model/
    │   │       ├── ppo_config_glm-350M-chinese.yml
    │   │       ├── ppo_config_glm-small.yml
    │   │       ├── ppo_config_pangu-2.6B.yml
    │   │       └── ppo_config_pangu-350M.yml
    │   ├── dict/
    │   │   └── jieba_custom_dict.txt
    │   ├── eval/
    │   │   ├── ceval/
    │   │   │   ├── subject_mapping.json
    │   │   │   └── submission_example.json
    │   │   └── mmlu/
    │   │       └── subject_mapping.json
    │   └── models/
    │       ├── baichuan-13B-base/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── baichuan-13B-chat/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── baichuan-7B/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm-6B/
    │       │   ├── README.md
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm2-6B/
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm3-6B/
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── glm-10B-chinese/
    │       │   ├── README.md
    │       │   ├── added_tokens.json
    │       │   ├── config.json
    │       │   ├── configuration_glm.py
    │       │   ├── modeling_glm.py
    │       │   ├── tokenization_glm.py
    │       │   └── tokenizer_config.json
    │       ├── glm-350M-chinese/
    │       │   ├── README.md
    │       │   ├── added_tokens.json
    │       │   ├── config.json
    │       │   ├── configuration_glm.py
    │       │   ├── modeling_glm.py
    │       │   ├── tokenization_glm.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-13B/
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-2.6B/
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-350M/
    │       │   ├── README-Reward.md
    │       │   ├── README-SFT.md
    │       │   ├── README.md
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       └── qwen-7B-chat/
    │           ├── config.json
    │           ├── configuration_qwen.py
    │           ├── generation_config.json
    │           ├── modeling_qwen.py
    │           ├── qwen_generation_utils.py
    │           ├── tokenization_qwen.py
    │           └── tokenizer_config.json
    ├── train_dpo.py
    ├── train_reward.py
    ├── train_rlhf.py
    ├── train_rlhf_trlx.py
    ├── train_sft.py
    ├── train_sft_wo_trainer.py
    └── utils/
        ├── __init__.py
        ├── config.py
        ├── file_utils.py
        ├── loading.py
        ├── logger.py
        ├── method_configs.py
        ├── modeling_utils.py
        └── nlp_utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
## Features

提供3大功能:
- LLM模型预训练:支持常见模型的预训练,包括:decoder结构(LLaMA、GPT)、encoder结构(GLM)
- LLM模型评测:参考GPT类模型,基于ZeroShot和FewShot实现
- ChatGPT模型训练pipeline:根据[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325) ,实现3大流程: SFT、Reward Model和RLHF
  - 支持RLHF阶段 (1) 联合优化reward和policy (2) 单独优化policy,冻结reward 
  - 支持[DPO](https://arxiv.org/abs/2305.18290)作为Reward+RLHF的替代方案,可显著降低显存占用,同时实现RL的效果

## Setup

### 1. Install deepspeed
```bash
git clone https://github.com/microsoft/DeepSpeed.git
cd DeepSpeed
rm -rf build
TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_OPS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1 | tee build.log
```
如果想创建binary wheel,方便在其他机器上安装,可使用如下命令,会在```dist```目录生成类似可安装文件```deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl```
```bash
git clone https://github.com/microsoft/DeepSpeed.git
cd DeepSpeed
rm -rf build
TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel 2>&1 | tee build.log
```
**PS**:需要根据下图,调整```TORCH_CUDA_ARCH_LIST="7.0"```为自己对应的NVIDIA GPU架构
![image info](./images/torch_cuda_list.png "torch_cuda_list")

或运行```torch.cuda.get_device_capability()```获取自己GPU的架构

### 2. Install jieba
在使用Pangu类模型的时候,其special_token格式为```<sep>```、```<pad>```等,而[tokenization_gptpangu.py](src/resources/models/pangu-350M/tokenization_gptpangu.py)中```tokenize()```函数会使用```jieba```进行分词。但直接```pip install jieba```,默认会将```<```和```>```直接切分开,使用```jieba.add_word("<sep>")```也没有作用,因为```jieba```直接hardcode了会自动切分的token,其中就包括了```<```和```>```。 

因此需要执行:
```bash
git clone https://github.com/fxsjy/jieba.git
cd jieba
```
将代码clone到本地,修改```jieba/__init__.py```中```re_han_default```的取值,具体改动如下:

- 改动前:
```python
re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
```

- 改动后:
```python
re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-<>]+)", re.U)
```

修改完成后使用```pip install .```进行本地编译安装,替换原有```jieba```。安装完成后,在代码中加入```jieba.add_word("<sep>")```(该代码已加入[tokenization_gptpangu.py](src/resources/models/pangu-350M/tokenization_gptpangu.py)),即可解决将```<sep>```一类的special token切分为多个id的情况


### 3. Install apex (Optional)
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .  2>&1 | tee build.log
```
如果想创建binary wheel,方便在其他机器上安装,可使用如下命令,会在```dist```目录生成类似可安装文件```apex-0.0.1+7150e20-cp38-cp38-linux_x86_64.whl```
```bash
git clone https://github.com/NVIDIA/apex
cd apex
python setup.py --cpp_ext --cuda_ext bdist_wheel 2>&1 | tee build.log
```


## Data & Model Download

### 1. 预训练模型下载

| 模型      | size | huggingface地址 | 百度网盘地址  |  提取码      | 
| ----------- | ----------- | ----------- |  ----------- |  ----------- |
| [Pangu-350M](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 659MB | [sunzeyeah/pangu-350M](https://huggingface.co/sunzeyeah/pangu-350M) | [Pangu-350M](https://pan.baidu.com/s/1IzgtW48S2PKyjxPPMe1rAQ) |  c5jj |
| [Pangu-2.6B](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 9.8GB | [sunzeyeah/pangu-2_6B](https://huggingface.co/sunzeyeah/pangu-2_6B) | [Pangu-2.6B](https://pan.baidu.com/s/1Tzvja4q_LgQOwkWPQ4jShw)    | 2rad |
| [Pangu-13B](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 23.6GB | [sunzeyeah/pangu-13B](https://huggingface.co/sunzeyeah/pangu-13B) | [Pangu-13B](https://pan.baidu.com/s/11fWAeYYKqI7pH0UiuJ5jEQ)    | u3dx |
| [GLM-350M-chinese](https://github.com/THUDM/GLM) | 679MB | [sunzeyeah/glm-350M-chinese](https://huggingface.co/sunzeyeah/glm-350M-chinese) | [GLM-350M-chinese](https://pan.baidu.com/s/11Lef-E7Tsz5OGOueCpiqaA) | ii8e |
| [GLM-10B-chinese](https://github.com/THUDM/GLM)   | 18.4G | [sunzeyeah/glm-10B-chinese](https://huggingface.co/sunzeyeah/glm-10B-chinese) | [GLM-10B-chinese](https://pan.baidu.com/s/1GuOefx42n_GzFfwnjoBltw) | fynj  |
| [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)   | 25.6G | [sunzeyeah/chatglm-6B](https://huggingface.co/sunzeyeah/chatglm-6B) | [ChatGLM-6B](https://pan.baidu.com/s/1OlpkMeQD6-LEpNFWx5E-mg) | uq1k |

**PS**: 本repo提供的预训练模型下载中,
- 对于pytorch_model\*.bin
  - 如果源文件已包括,则不做改动
  - 如果源文件不包括,则根据其提供的checkpoint转换为pytorch_model\*.bin
- 其余文件可能相对原文件有改动,包括:modeling_\*.py、tokenization_\*.py、configuration_\*.py、config.json和tokenizer_config.json

### 2. 数据下载

| 数据集      | size | huggingface地址 | 百度网盘地址  |  提取码      | 
| ----------- | ----------- | ----------- |  ----------- |  ----------- |
| [CLUE Benchmark](https://github.com/CLUEbenchmark/CLUE)   | 500MB | | [CLUE Benchmark](https://pan.baidu.com/s/15F60nRbBd6d6UvyXdwbXQA) |  m6gt |
| SFT & Reward Data  | 5GB | [sunzeyeah/chinese_chatgpt_corpus](https://huggingface.co/datasets/sunzeyeah/chinese_chatgpt_corpus) | [SFT & Reward Data](https://pan.baidu.com/s/1sl8PB-Dlt1xLIYczMODyRg) | ecyc |
| [百科](https://github.com/brightmart/nlp_chinese_corpus)  | 652MB | | [baike_qa_2019](https://pan.baidu.com/s/1N6I-fvx6FLHwuxJuDLLA8g) | 7jad |
| [知道问答](https://github.com/SophonPlus/ChineseNlpCorpus) | 847MB | | [zhidao](https://pan.baidu.com/s/1sjR3vABiMbdV1HkUt6kCKQ) | neds |
| [对联](https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz)  | 221MB | | [couplets](https://pan.baidu.com/s/1oo6oaephPXpyU-sDd_37qg) | 54ey |
| [古文](https://github.com/NiuTrans/Classical-Modern)  | 125MB | | [Classical & Modern](https://pan.baidu.com/s/1ZyGkYOgzT9ZEVnpSpB4kQg) | a4cr |
| [古诗词](https://github.com/chinese-poetry/chinese-poetry)  | 87MB | | [chinese poetry](https://pan.baidu.com/s/13uvkA96PdKpKB7ZQ2GkXZQ) | 5zzj |
| 微博新闻评论  | 522MB | | [weibo summary comments](https://pan.baidu.com/s/1h45O0q6gQl3LbH-NjzuRlw) | w0g1 |

**PS**: SFT & Reward Data基于百科、知道问答、对联、古文、古诗词、微博新闻评论数据构造,可直接用于SFT和Reward阶段训练。详见[data_prepare.py](./src/data_prepare.py)


## Usage

### 1. LLM模型预训练
对开源LLM进行增量预训练,基于deepspeed实现。目前支持2类模型架构:
- decoder结构:LLaMA、Baichuan、Pangu
- encoder结构:GLM、ChatGLM

```bash
cd examples
bash pretrain.sh
```

### 2. LLM模型评测
对开源中文LLM进行ZeroShot、OneShot或FewShot的评测。详见[eval_pretrain.py](./src/eval_pretrain.py) 和 [data.py](src/data/data.py)。

目前支持的评测任务:
- [C-Eval](https://github.com/SJTU-LIT/ceval)
- [MMLU](https://github.com/hendrycks/test)
- [CLUEBenchmark](https://github.com/CLUEbenchmark/CLUE) :评测方法和prompt模板参考[Pangu-alpha论文](https://arxiv.org/abs/2104.12369) 

目前支持的开源模型: 
- LLaMA及相关衍生模型 
- ChatGLM(1和2)
- Baichuan
- Qwen
- Pangu
- GLM


```bash
cd examples
bash eval_pretrain.sh
```

### 3. SFT
使用开源LLM + SFT&Reward数据进行SFT训练
```bash
cd examples
bash train_sft.sh
```
### 4. Reward Model
使用SFT模型 + SFT&Reward数据进行Reward模型训练
```bash
cd examples
bash train_reward.sh
```

### 5. RLHF
利用PPO算法和Reward Model,进一步更新SFT模型。基于开源框架[DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat) 实现
```bash
cd examples
bash train_rlhf.sh
```

### 6. DPO
利用DPO算法替代Reward+RLHF的pipeline,免去训练Reward模型,同时达到RL训练的效果,该方法可显著降低显存占用。基于开源框架[trl](https://github.com/huggingface/trl) 实现
```bash
cd examples
bash train_dpo.sh
```


## Results

### 1. LLM模型评测

<details>
<summary><b>C-Eval 5-shot测试集(test)结果</b></summary>
<table>
    <tr>  <td>Model</td>  <td>Avg</td>  <td>Avg(Hard)</td> <td>STEM</td> <td>Social Science</td> <td>Humanities</td> <td>Other</td> </tr>
    <tr>  <td>Baichuan2-13B-Chat</td>  <td style="color:red"><b>56.30</b></td>  <td>34.20</td>  <td style="color:red"><b>48.20</b></td>  <td style="color:red"><b>70.00</b></td>  <td style="color:red"><b>60.50</b></td>  <td>54.20</td> </tr>    
    <tr>  <td>xverse-13B</td>  <td>55.30</td>  <td>32.50</td>  <td>45.90</td>  <td>66.70</td>  <td>59.50</td>  <td style="color:red"><b>57.60</b></td> </tr>    
    <tr>  <td>Qwen-7B-Chat</td>  <td>54.70</td>  <td>35.40</td>  <td>47.90</td>  <td>68.30</td>  <td>58.70</td>  <td>50.00</td> </tr>    
    <tr>  <td>Baichuan-13B-Base</td>  <td >53.70</td>  <td style="color:red"><b>35.60</b></td>  <td>46.80</td>  <td>65.80</td>  <td>58.00</td>  <td>50.80</td> </tr>    
    <tr>  <td>Baichuan2-7B-Chat</td>  <td>52.50</td>  <td>33.80</td>  <td>45.70</td>  <td>64.20</td>  <td>56.60</td>  <td>50.20</td> </tr>
    <tr>  <td>ChatGLM2-6B</td>  <td>51.20</td>  <td>33.40</td>  <td>46.90</td>  <td>63.00</td>  <td>51.60</td>  <td>47.70</td> </tr>
    <tr>  <td>Baichuan-13B-Chat</td>  <td>47.90</td>  <td>31.50</td>  <td>41.40</td>  <td>56.80</td>  <td>53.00</td>  <td>46.50</td> </tr>
    <tr>  <td>Baichuan-7B</td>  <td>44.20</td>  <td>31.70</td>  <td>39.20</td>  <td>53.30</td>  <td>47.30</td>  <td>41.90</td> </tr>
    <tr>  <td>Ziya-LLaMA-13B-v1.1</td>  <td>40.10</td>  <td>30.30</td>  <td>35.80</td>  <td>47.30</td>  <td>42.80</td>  <td>38.50</td> </tr>
    <tr>  <td>ChatGLM1.1-6B</td>  <td>38.10</td>  <td>28.60</td>  <td>33.60</td>  <td>46.70</td>  <td>40.90</td>  <td>35.70</td> </tr>
    <tr>  <td>AtomGPT-13B-56k</td>  <td>37.60</td>  <td>25.30</td>  <td>32.00</td>  <td>44.70</td>  <td>42.80</td>  <td>36.10</td> </tr>
    <tr>  <td>LLaMA2-13B-chat</td>  <td>37.10</td>  <td>29.30</td>  <td>34.60</td>  <td>43.60</td>  <td>35.90</td>  <td>37.00</td> </tr>
    <tr>  <td>ChatGLM-6B</td>  <td>36.30</td>  <td>27.20</td>  <td>32.90</td>  <td>42.80</td>  <td>38.10</td>  <td>34.90</td> </tr>
    <tr>  <td>LLaMA-30B</td>  <td>35.90</td>  <td>29.90</td>  <td>34.40</td>  <td>42.40</td>  <td>33.30</td>  <td>35.60</td> </tr>
    <tr>  <td>LLaMA2-7B-chat</td>  <td>33.50</td>  <td>27.30</td>  <td>31.60</td>  <td>38.10</td>  <td>33.80</td>  <td>32.70</td> </tr>
    <tr>  <td>Ziya-LLaMA-13B-Pretrain-v1</td>  <td>31.10</td>  <td>22.20</td>  <td>27.40</td>  <td>36.50</td>  <td>33.80</td>  <td>30.40</td> </tr>
    <tr>  <td>LLaMA-13B</td>  <td>29.80</td>  <td>24.20</td>  <td>28.40</td>  <td>33.70</td>  <td>29.60</td>  <td>29.00</td> </tr>
    <tr>  <td>LLaMA-7B</td>  <td>26.80</td>  <td>26.70</td>  <td>26.20</td>  <td>27.60</td>  <td>25.70</td>  <td>28.10</td> </tr>
</table>
</details>

<details>
<summary><b>MMLU 5-shot测试集(test)结果</b></summary>
<table>
    <tr>  <td>Model</td>  <td>Avg</td>  <td>STEM</td> <td>Social Science</td> <td>Humanities</td> <td>Other</td> </tr>
    <tr>  <td>Baichuan2-13B-Chat</td>  <td style="color:red"><b>56.90</b></td>  <td style="color:red"><b>47.28</b></td>  <td style="color:red"><b>66.23</b></td>  <td>52.90</td>  <td style="color:red"><b>63.50</b></td> </tr>
    <tr>  <td>LLaMA-30B</td>  <td>56.33</td>  <td>44.68</td>  <td>65.64</td>  <td style="color:red"><b>54.60</b></td>  <td>61.57</td> </tr>
    <tr>  <td>xverse-13B</td>  <td>55.24</td>  <td>45.60</td>  <td>64.51</td>  <td>50.32</td>  <td>63.27</td> </tr>
    <tr>  <td>Qwen-7B-Chat</td>  <td>54.13</td>  <td>41.76</td>  <td>63.43</td>  <td>50.81</td>  <td>62.50</td> </tr>
    <tr>  <td>LLaMA2-13B-chat</td>  <td>53.98</td>  <td>44.52</td>  <td>63.40</td>  <td>49.37</td>  <td>61.21</td> </tr>
    <tr>  <td>Baichuan-13B-Base</td>  <td>53.46</td>  <td>43.86</td>  <td>63.14</td>  <td>49.73</td>  <td>59.28</td> </tr>
    <tr>  <td>Baichuan2-7B-Chat</td>  <td>53.11</td>  <td>43.51</td>  <td>62.26</td>  <td>49.58</td>  <td>59.12</td> </tr>
    <tr>  <td>Baichuan-13B-Chat</td>  <td>51.12</td>  <td>41.61</td>  <td>59.11</td>  <td>47.52</td>  <td>58.31</td> </tr>
    <tr>  <td>Ziya-LLaMA-13B-v1.1</td>  <td>51.06</td>  <td>41.89</td>  <td>57.71</td>  <td>49.22</td>  <td>56.54</td> </tr>
    <tr>  <td>LLaMA2-7B-chat</td>  <td>48.10</td>  <td>39.64</td>  <td>56.28</td>  <td>43.61</td>  <td>55.39</td> </tr>
    <tr>  <td>LLaMA-13B</td>  <td>46.51</td>  <td>37.23</td>  <td>52.71</td>  <td>44.35</td>  <td>53.04</td> </tr>
    <tr>  <td>ChatGLM2-6B</td>  <td>45.83</td>  <td>38.75</td>  <td>52.06</td>  <td>43.20</td>  <td>50.82</td> </tr>
    <tr>  <td>AtomGPT-13B-56k</td>  <td>42.75</td>  <td>36.02</td>  <td>49.04</td>  <td>38.80</td>  <td>49.30</td> </tr>
    <tr>  <td>Baichuan-7B</td>  <td>41.96</td>  <td>36.63</td>  <td>47.77</td>  <td>37.55</td>  <td>48.31</td> </tr>
    <tr>  <td>Ziya-LLaMA-13B-Pretrain-v1</td>  <td>41.61</td>  <td>33.61</td>  <td>46.01</td>  <td>39.85</td>  <td>48.05</td> </tr>
    <tr>  <td>ChatGLM1.1-6B</td>  <td>40.07</td>  <td>32.95</td>  <td>44.55</td>  <td>39.23</td>  <td>44.12</td> </tr>
    <tr>  <td>ChatGLM-6B</td>  <td>37.87</td>  <td>32.41</td>  <td>43.80</td>  <td>35.60</td>  <td>41.00</td> </tr>
    <tr>  <td>LLaMA-7B</td>  <td>28.53</td>  <td>26.10</td>  <td>28.76</td>  <td>28.52</td>  <td>24.81</td> </tr>
</table>
</details>

<details>
<summary><b>CLUEBenchmark 验证集(dev.json)结果</b></summary>

<table>
    <tr>  <td rowspan="2">Dataset</td>  <td rowspan="2">Method</td>  <td rowspan="2">Metrics</td>  <td rowspan="2">Task Type</td>  <td colspan="5" style="text-align:center">Zero-shot</td>  <td colspan="5" style="text-align:center">Few-shot</td> </tr>
    <tr>  <td>GLM-350M-chinese</td>  <td>Pangu-350M</td>  <td>Pangu-2.6B</td>  <td>GLM-10B-chinese</td>  <td>Pangu-13B</td>  <td>GLM-350M-chinese</td>  <td>Pangu-350M</td>  <td>Pangu-2.6B</td>  <td>GLM-10B-chinese</td>  <td>Pangu-13B</td> </tr>
    <tr>  <td>OCNLI</td>  <td>PPL</td>  <td>acc</td>  <td>NLI</td>  <td>0.3074</td>  <td style="color:red"><b>0.3369</b></td>  <td>0.3061</td>  <td>0.3288</td>  <td>0.3301</td>  <td>0.3298</td>  <td>0.3352</td>  <td>0.3216</td>  <td></td>  <td></td> </tr>
    <tr>  <td>CMNLI</td>  <td>PPL</td>  <td>acc</td>  <td>NLI</td>  <td>0.3279</td>  <td>0.3302</td>  <td>0.3310</td>  <td>0.3338</td>  <td style="color:red"><b>0.3358</b></td>  <td>0.3356</td>  <td>0.3328</td>  <td>0.3300</td>  <td></td>  <td></td> </tr>
    <tr>  <td>CHID</td>  <td>PPL</td>  <td>acc</td>  <td>Cloze(multi-choices)</td>  <td>0.0734</td>  <td>0.0916</td>  <td>0.0670</td>  <td>0.1016</td>  <td style="color:red"><b>0.1018</b></td>  <td>0.0979</td>  <td>0.1007</td>  <td>0.0996</td>  <td></td>  <td></td> </tr>
    <tr>  <td>CMRC2018</td>  <td>generation</td>  <td>f1</td>  <td>MRC</td>  <td>0.093</td>  <td>0.0979</td>  <td>0.1007</td>  <td style="color:red"><b>0.1392</b></td>  <td>0.021</td>  <td>0.09345</td>  <td>0.097</td>  <td>0.1007</td>  <td></td>  <td></td> </tr>
    <tr>  <td>CLUEWSC2020</td>  <td>PPL</td>  <td>acc</td>  <td>WSC</td>  <td>0.4934</td>  <td>0.5328</td>  <td style="color:red"><b>0.5592</b></td>  <td>0.5131</td>  <td>0.4671</td>  <td>0.5526</td>  <td>0.4473</td>  <td>0.4671</td>  <td></td>  <td></td> </tr>
    <tr>  <td>C3</td>  <td>PPL</td>  <td>acc</td>  <td>Common sense reasoning</td>  <td>0.2360</td>  <td>0.2426</td>  <td>0.2418</td>  <td style="color:red"><b>0.2573</b></td>  <td>0.2567</td>  <td>0.2476</td>  <td>0.2559</td>  <td>0.2515</td>  <td></td>  <td></td> </tr>
    <tr>  <td>AFQMC</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td style="color:red"><b>0.6306</b></td>  <td>0.4582</td>  <td>0.4914</td>  <td>0.4960</td>  <td>0.5000</td>  <td>0.4872</td>  <td>0.4993</td>  <td>0.5018</td>  <td></td>  <td></td> </tr>
    <tr>  <td>CSL</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.4943</td>  <td>0.4913</td>  <td>0.4666</td>  <td style="color:red"><b>0.5126</b></td>  <td>0.4996</td>  <td>0.5140</td>  <td>0.5036</td>  <td>0.4973</td>  <td></td>  <td></td> </tr>
    <tr>  <td>IFLYTEK</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.1292</td>  <td style="color:red"><b>0.3058</b></td>  <td>0.265</td>  <td>0.2620</td>  <td>0.2408</td>  <td>0.2539</td>  <td>0.2535</td>  <td>0.2524</td>  <td></td>  <td></td> </tr>
    <tr>  <td>TNEWS</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.1582</td>  <td>0.2022</td>  <td>0.2449</td>  <td>0.2489</td>  <td style="color:red"><b>0.2527</b></td>  <td>0.2555</td>  <td>0.2466</td>  <td>0.2494</td>  <td></td>  <td></td> </tr>
</table>
</details>

### 2. SFT

模型训练参数:

| 模型 | 可训练参数量 | 数据量 | batch size | sequence length | 硬件 | 显存占用 | speed | Hours per epoch |
| --- | --- | --- | --- | :---: | :---: | :---: | --- | --- |
| GLM-350M-chinese | 355M | 5.4M | 4 | 512 | V100 16G | 13G | 3.7 s/iter | 88h |
| Pangu-350M | 345M | 5.4M | 48 | 512 | A100 80G | 78G | 1.91 s/iter | 27.5h |
| Pangu-2.6B | 2.6B | 5.4M |  8 | 512 | A100 80G | 79.4G | 9.61 s/iter | 116h |

SFT模型下载:

| 预训练模型  |  SFT数据集  | size | huggingface地址 | 百度网盘地址  |  提取码      | 
| ----------- | --- | ----------- | ----------- |  ----------- | ----------- |
| Pangu-350M |  SFT & Reward Data | 1.3GB | [sunzeyeah/pangu-350M-sft](https://huggingface.co/sunzeyeah/pangu-350M-sft) | [Pangu-350M-SFT](https://pan.baidu.com/s/14nF63nAQz38jmplUC0EQ-g) | 0f8j |
| Pangu-2.6B |  SFT & Reward Data | 9.8GB | [sunzeyeah/pangu-2_6B-sft](https://huggingface.co/sunzeyeah/pangu-2_6B-sft) | [Pangu-2.6B-SFT](https://pan.baidu.com/s/1Q2PKf0MnFsiFlNhU-6WIrg) | r13w |


<details>
<summary><b>Pangu-2.6B-SFT生成示例</b></summary>

```
# 生成参数
max_length = 512
max_length_generation = 100
top_k = 10
temperature = 0.8
do_sample=True
```
```
[Input] 我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点<sep>回答:
[Generated] 你爱他就多关心他,少在他面前表现出你的脆弱。这样他才更会爱你。


[Input] 一次次买彩,一次次倾听那比分,一次次兴奋,一次次失落,一次次无归而返!最终还是任8~~~~<sep>回答:
[Generated] 我也是这样,买彩的时间也不长,但是每次都是这样。。。。。。。

[Input] 现代文:太祖很高兴地说:凭着这样的部下去作战,还担心不能取胜么?<sep>古文:
[Generated] 太祖悦曰:此自是以去,何忧不胜?

[Input] 现代文:十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺取郢州之南门堡。<sep>古文:
[Generated] 十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺郢州之南门堡。

[Input] 以孔⁇的风格,写一篇经书,题为中庸<sep>
[Generated] 我闻昔者夫子之言也,曰:吾道一以贯之。其或行之或不行,吾未知之。是故君子慎其独也。吾未见好德如好色者也。好德如好色者,未之有也。好德如好色者,既已之有也。君子不以言举人,不以人废言。好德如好色者,其如命何!吾

[Input] 以陆游的风格,写一首七言诗,题为晨起<sep>
[Generated] 秋光满纸清⁇明,睡醒一江万马奔。自笑身如枯木鹤,何妨食饱作闲人。

[Input] 上联:霜风几笔雕残画<sep>下联:
[Generated] 云海千山叠翠峰

[Input] 新闻内容:据悉,这名14岁的女生在Facebook上发布了戏谑考试的博文,其中使用了胡志明1946年反对法国殖民者的讲话,随后被越南当局停课一年。越南官员称,该学生歪曲了历史事件,侮辱了老师的情感。多数网友虽对越南当局做法不满,但都表示,至少他们还能上Facebook.<sep>摘要:越南女生Facebook上模仿胡志明讲话被停课<sep>评论:
[Generated] 这女的,还真有勇气,我喜欢你!

[Input] 新闻内容:习近平天津考察时勉励大学生志存高远、脚踏实地,转变择业观念,坚持从实际出发,勇于到基层一线和艰苦地方去,把人生的路一步步走稳走实,善于在平凡岗位上创造不平凡的业绩<sep>摘要:习近平勉励大学生转变择业观念勇于到基层和艰苦地方去<sep>评论:
[Generated] 现在的社会,很多年轻人都是在父母的安排下,去了不同的单位,在不同的地方工作,在不同的城市生活,这就导致了年轻人的思维方式,价值观都不一样,这也导致了年轻人的成长道路,也有很多的不安全感,这也是很正常的事情,年轻人需要正确的引导,正确的引导,也需要正确的引导!

[Input] 问题:劳动仲裁申请法院强制执行,法院强制执行需要多长时间<sep>回答:
[Generated] 一般申请法院强制执行需要3-6个月时间,具体需要看法院的执行措施。

[Input] 问题:银行贷款50万,30年!年利率是4.41怎么算一个月还多少钱<sep>回答:
[Generated] 若申请的是招行贷款,要计算贷款利息或者每月还款金额,需要知道贷款本金,贷款期限,还款方式,贷款年利率。若上述信息都确认到,可以通过我行贷款计算器尝试计算,登录招行官网右下方找到“理财计算器”--“个人贷款计算器”可以计算。
```
</details>


### 3. Reward Model

模型训练参数:

| 模型 | 可训练参数量 | 数据量 | batch size | sequence length | 硬件 | 显存占用 | speed | Hours per epoch |
| --- | --- | --- | --- | :---: | :---: | :---: | --- | --- |
| Pangu-350M | 131M | 12M | 32 | 512 | A100 80G | 72.6G | 1.91 s/iter | 105h |
| Pangu-2.6B | 815M | 12M |  8 | 512 | A100 80G | 80.7G |  | 423h |


Reward模型下载:

| 预训练模型  |  训练数据集  | size | huggingface地址 | 百度网盘地址  |  提取码      | 
| ----------- | --- | ----------- | ----------- |  ----------- | ----------- |
| Pangu-350M |  SFT & Reward Data | 1.3GB | [sunzeyeah/pangu-350M-reward](https://huggingface.co/sunzeyeah/pangu-350M-reward) | [Pangu-350M-Reward](https://pan.baidu.com/s/1wC3w78t7pVn0Xn5tJHy06A) | 4gju |


### 4. RLHF

To be updated

### 5. DeepSpeed实验

为验证不同预训练模型使用deepspeed的训练效率是否能达到官方宣称的效果(加速、节省GPU等),进行了benchmarking
- 实验场景:SFT阶段训练
- 实验参数:```max_sequence_length=512```

<details>
<summary><b>DeepSpeed实验结果</b></summary>
<table>
   <tr> <td>模型</td> <td>数据</td>  <td>整体耗时/epoch</td>  <td>单条样本耗时</td>  <td>内存使用量</td>  <td>显存使用量</td>  <td>GPU型号和数量</td> <td>fp16</td> <td>bf16</td> <td>deepspeed stage</td> <td>offload optimizer</td> <td>pin memory</td> <td>offload param</td> <td>overlap comm</td> <td>allgather bucket size</td> <td>stage3 max live parameters</td> <td>batch size</td> <td>gradient accumulation steps</td> <td>gradient checkpointing</td> <td>model half</td> </tr>
   <tr> <td rowspan="11">T5-large</td> <td rowspan="11">wmt16-en-ro, 共计61万条样本</td> <td>43h</td>  <td>0.5s/it</td>  <td>7.1G</td>  <td>1*14529MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>152h</td>  <td>1.78s/it</td>  <td>38.26G</td>  <td>1*11663MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>250h</td>  <td>2.95s/it</td>  <td>38.74G</td>  <td>1*7255MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>62h</td>  <td>5.8s/it</td>  <td>86.81G</td>  <td>8*7811MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>16</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>16</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>290h</td>  <td>3.48s/it</td>  <td>46.53G</td>  <td>1*6655MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>380h</td>  <td>4.5s/it</td>  <td>43.48G</td>  <td>1*5263MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>215h</td>  <td>4.9s/it</td>  <td>47.31G</td>  <td>2*5019MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>1370h</td>  <td>64s/it</td>  <td>57.55G</td>  <td>4*4701MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>948h</td>  <td>90s/it</td>  <td>72.54G</td>  <td>8*4585MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td rowspan="7">Pangu-2.6B</td> <td rowspan="7">SFT & Reward Data的验证集,共1万条样本</td> <td>2h</td>  <td>5.76s/it</td>  <td>67.86G</td>  <td>1*15631MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>2.1h</td>  <td>6.15s/it</td>  <td>67.88G</td>  <td>1*15705MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>4.5h</td>  <td>13.3s/it</td>  <td>81.02G</td>  <td>1*15449MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>11.5h</td>  <td>8.2s/it</td>  <td>75.89G</td>  <td>1*15299MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>5.5h</td>  <td>7.8s/it</td>  <td>81.16G</td>  <td>2*14851MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>6.2h</td>  <td>18.3s/it</td>  <td>97.31G</td>  <td>4*14389MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td>6.6h</td>  <td>38s/it</td>  <td>118.82G</td>  <td>8*14335MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>
   <tr> <td rowspan="14">ChatGLM-6B</td> <td rowspan="14">SFT & Reward Data的验证集,共1万条样本</td> <td>-</td>  <td>-</td>  <td>120.45G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>120.48G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e3</td> <td>-</td> <td>1</td> <td>8</td> <td>false</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>153.02G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>false</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>154G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>21.2h</td>  <td>60s/it</td>  <td>154G</td>  <td>1*10443MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>auto</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>21.5h</td>  <td>60s/it</td>  <td>152.81G</td>  <td>1*10409MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>23.5h</td>  <td>65s/it</td>  <td>153.36G</td>  <td>1*9229MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>14h</td>  <td>80s/it</td>  <td>158.21G</td>  <td>2*8631MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>7.8h</td>  <td>90s/it</td>  <td>168.38G</td>  <td>4*6743MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>4h</td>  <td>90s/it</td>  <td>189.34G</td>  <td>8*6729MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>1h</td>  <td>100s/it</td>  <td>189.38G</td>  <td>8*10047MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>4</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>50min</td>  <td>40s/it</td>  <td>189.39G</td>  <td>8*14763MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>8</td> <td>2</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>35min</td>  <td>113s/it</td>  <td>189.39G</td>  <td>8*14763MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>8</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>189.34G</td>  <td>OOM</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>10</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td rowspan="11">GLM-10B-Chinese</td> <td rowspan="11">SFT & Reward Data的验证集,共1万条样本</td> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>auto</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>OOM</td>  <td>-</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>6*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
   <tr> <td>-</td>  <td>-</td>  <td>OOM</td>  <td>-</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>
</table>
</details>

**PS**: deepspeed的参数介绍和调优经验,可参见[DeepSpeed Configuration](src/resources/config/deepspeed/README.md)

### 6. LoRA实验

为验证LoRA的训练效率提升,进行了benchmarking

- 实验场景:SFT阶段训练
- 实验数据:SFT & Reward Data的验证集,共1万条样本
- 实验参数:```max_sequence_length=512, lora_alpha=1, lora_train_bias='none'```

<details>
<summary><b>LoRA实验结果</b></summary>
<table>
   <tr> <td>模型</td> <td>LoRA rank</td> <td>可训练参数量</td> <td>deepspeed</td> <td>batch size</td> <td>GPU型号和数量</td> <td>显存使用量</td> <td>单条样本耗时</td> <td>整体耗时/epoch</td> </tr>
   <tr> <td rowspan="8">Pangu-2.6B</td>  <td>-</td>  <td>2.6B</td>  <td>-</td>  <td>8</td>  <td>1*A100 80G</td>  <td>1*79421MB</td>  <td>9.66s/it</td>  <td>12.5min</td> </tr>
   <tr> <td>1000</td>  <td>1.5B</td>  <td>-</td>  <td>8</td>  <td>1*A100 80G</td>  <td>1*76129MB</td>  <td>11.61s/it</td>  <td>15min</td> </tr>
   <tr> <td>500</td>  <td>758M</td>  <td>-</td>  <td>12</td>  <td>1*A100 80G</td>  <td>1*77179MB</td>  <td>16.2s/it</td>  <td>14min</td> </tr>
   <tr> <td>100</td>  <td>151M</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*81103MB</td>  <td>18.6s/it</td>  <td>12min</td> </tr>
   <tr> <td>50</td>  <td>75M</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*80809MB</td>  <td>17.8s/it</td>  <td>11.5min</td> </tr>
   <tr> <td>10</td>  <td>15M</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*78735MB</td>  <td>17.6s/it</td>  <td>11.5min</td> </tr>
   <tr> <td>100</td>  <td>151M</td>  <td>stage=2, w offloading</td>  <td>24</td>  <td>1*A100 80G</td>  <td>1*76933MB</td>  <td>25.5s/it</td>  <td>11min</td> </tr>
   <tr> <td>100</td>  <td>151M</td>  <td>stage=3, w offloading</td>  <td>24</td>  <td>1*A100 80G</td>  <td>1*77259MB</td>  <td>46.5s/it</td>  <td>20min</td> </tr>
   <tr> <td rowspan="3">ChatGLM-6B</td>  <td>-</td>  <td>6.2B</td>  <td>-</td>  <td>3</td>  <td>1*A100 80G</td>  <td>1*79206MB</td>  <td>6.7s/it</td>  <td>23.5min</td> </tr>
   <tr> <td>1000</td>  <td>1.9B</td>  <td>-</td>  <td>6</td>  <td>1*A100 80G</td>  <td>1*78840MB</td>  <td>12.8s/it</td>  <td>22.5min</td> </tr>
   <tr> <td>500</td>  <td>994M</td>  <td>-</td>  <td>6</td>  <td>1*A100 80G</td>  <td>1*68832MB</td>  <td>12.4s/it</td>  <td>21.5min</td> </tr>
</table>
</details>

================================================
FILE: docs/INSTALL_ON_WINDOWS.md
================================================
大模型相关的常见package在***Windows***上的安装方案。在以下配置上已测试成功: ```windows11 + python3.11 + cuda12.3 + cudnn8.9.5.30 + torch2.1.0+cu121 + ninja1.11.1 + cmake3.28.0-rc5 + Visual Studio 2022```

## 0. Prerequisite
#### a. 确认NVIDIA GPU架构
可根据下图
![image info](../images/torch_cuda_list.png "torch_cuda_list")

或运行```torch.cuda.get_device_capability()```获取

#### b. 软件安装
需安装Visual Studio、ninja、cmake和CUDA

- Microsoft Visual Studio 2022安装注意事项如下,`单个组件`需勾选6个额外模块,[参考链接](https://blog.csdn.net/feking2/article/details/130160845):

![image info](../images/microsoft_visual_studio_components.png "microsoft_visual_studio_components")

- 安装顺序:必须先安装Visual Studio,再安装CUDA

## 1. apex
目前为止没有成功

## 2. cutlass
打开Git Bash
```bash
# 设定环境变量
export CUDACXX=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.3/bin/nvcc

git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
mkdir build && cd build

# 显卡架构的数字根据自己情况修改
cmake .. -DCUTLASS_NVCC_ARCHS=89 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_LIBRARY=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON
cmake --build . --config Release -j 8
```
在以下commit-id: `b5d8a5d9ccb435268e2215144dca162b0fc6a826`编译成功

## 3. flash attention
打开Git Bash
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
# IMPORTANT: do not manually copy cmake-built (or incorrect version?) cutlass code to flash-attention/csrc/cutlass
MAX_JOBS=8 python setup.py install
```
在以下commit-id: `017716451d446e464dde9aca3a3c1ed2209caaa9`安装成功

- benchmark

  - 测试代码: https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py

  - 测试设备:单张4090

  <details>
    <summary>结果</summary>
    <table>
      <tr> <td rowspan="2" style="text-align:center">is causal</td>  <td rowspan="2" style="text-align:center">head dim</td> <td rowspan="2" style="text-align:center">batch size</td> <td rowspan="2" style="text-align:center">seq length</td> <td rowspan="2" style="text-align:center">Model</td> <td colspan="3" style="text-align:center">Speed (TFLOPS/s)</td> </tr>
      <tr> <td>fwd</td>  <td>bwd</td>  <td>fwd+bwd</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Flash2</td> <td>149.21</td> <td>91.05</td> <td>102.46</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Pytorch</td> <td>20.04</td> <td>22.91</td> <td>22.01</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.cutlass</td> <td>126.21</td> <td>74.03</td> <td>83.95</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.flash</td> <td>161.29</td> <td>99.23</td> <td>111.49</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Flash2</td> <td>167.81</td> <td>124.49</td> <td>134.40</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Pytorch</td> <td>21.43</td> <td>25.68</td> <td>24.30</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.cutlass</td> <td>132.98</td> <td>78.03</td> <td>88.48</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.flash</td> <td>164.71</td> <td>124.29</td> <td>133.66</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Flash2</td> <td>169.75</td> <td>123.30</td> <td>133.76</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Pytorch</td> <td>25.58</td> <td>27.19</td> <td>26.71</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.cutlass</td> <td>136.01</td> <td>77.43</td> <td>88.29</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.flash</td> <td>168.51</td> <td>136.41</td> <td>144.26</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>Flash2</td> <td>141.27</td> <td>142.01</td> <td>141.79</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>Pytorch</td> <td>26.39</td> <td>0.64</td> <td>0.89</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.cutlass</td> <td>134.51</td> <td>67.28</td> <td>78.49</td> </tr>
      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.flash</td> <td>5.52</td> <td>152.40</td> <td>17.71</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>Flash2</td> <td>2.13</td> <td>43.81</td> <td>6.65</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>Pytorch</td> <td>7.49</td> <td>10.79</td> <td>9.58</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.cutlass</td> <td>100.97</td> <td>56.74</td> <td>64.85</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.flash</td> <td>2.38</td> <td>7.61</td> <td>4.68</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>Flash2</td> <td>3.36</td> <td>89.96</td> <td>10.73</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>Pytorch</td> <td>8.35</td> <td>12.27</td> <td>10.82</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.cutlass</td> <td>115.96</td> <td>58.43</td> <td>68.08</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.flash</td> <td>3.06</td> <td>90.85</td> <td>9.87</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>Flash2</td> <td>5.24</td> <td>112.91</td> <td>16.44</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>Pytorch</td> <td>8.76</td> <td>13.39</td> <td>11.63</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.cutlass</td> <td>123.37</td> <td>67.61</td> <td>77.63</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.flash</td> <td>4.37</td> <td>115.97</td> <td>13.99</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>Flash2</td> <td>7.26</td> <td>134.79</td> <td>22.40</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>Pytorch</td> <td>0.22</td> <td>0.35</td> <td>0.30</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.cutlass</td> <td>129.03</td> <td>70.07</td> <td>80.60</td> </tr>
      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.flash</td> <td>6.25</td> <td>103.89</td> <td>19.02</td> </tr>
    </table>
  </details>


## 4. xformers
打开任意终端,如:cmd、powershell或git bash
```cmd
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
```

## 5. deepspeed
打开Git Bash
```bash
git clone https://github.com/microsoft/DeepSpeed.git
cd DeepSpeed
rm -rf build
# 显卡架构的数字根据自己情况修改
TORCH_CUDA_ARCH_LIST=8.9 DS_BUILD_OPS=0 DS_BUILD_AIO=0 DS_BUILD_CCL_COMM=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_CPU_ADAGRAD=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_TRANSFORMER_INFERENCE=0 DS_BUILD_INFERENCE_CORE_OPS=0 DS_BUILD_CUTLASS_OPS=0 python setup.py build_ext -j8 bdist_wheel 2>&1 | tee build.log
# 完成后会在dist目录生成whl文件
pip install dist/deepspeed-0.12.4+a3926bbb-cp311-cp311-win_amd64.whl
```
在以下commit-id: `a3926bbbf6d0025b5c6076a280e6b91ebd08aada`安装成功

***PS***:由于deepspeed需要调用NCCL backend进行分布式训练,但NCCL不支持Windows,所以目前虽然可以安装deepspeed,但无法进行训练


================================================
FILE: examples/data_prepare.sh
================================================
#!/bin/bash
# Preprocess the raw data under Data/chatgpt/raw into Data/chatgpt/processed
# by running src/data_prepare.py with the tokenizer/model found at MODEL_PATH.

MODEL="pangu-2.6B"

# Root directory holding Code/ and Data/ (uncomment the one for the current machine).
ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp/"
DATA_DIR="${ROOT}/Data/chatgpt/raw"
#MAIN=$ROOT/Code/chatgpt/src/data_prepare.py
MAIN="${ROOT}/Code/RLHF/src/data_prepare.py"
MODEL_PATH="${ROOT}/Data/models/${MODEL}"
#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
OUTPUT_DIR="${ROOT}/Data/chatgpt/processed"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"

# All expansions are quoted so paths containing spaces or glob characters survive intact.
python "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}"

================================================
FILE: examples/eval_pretrain.sh
================================================
#!/bin/bash
# Few-shot evaluation of a pretrained model with src/eval_pretrain.py.
# The job runs in the background; stdout/stderr go to a timestamped log under out/.

# C-Eval and MMLU benchmarks
TASK="ceval"
MODEL="llama-7B"
ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/raw/${TASK}"
#MAIN=$ROOT/Code/chatgpt/src/eval_pretrain.py
MAIN="${ROOT}/Code/RLHF/src/eval_pretrain.py"
MODEL_PATH="${ROOT}/Data/models/${MODEL}"
#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/pretrain/${MODEL}"
EVAL_FILENAME="val"
TRAIN_FILENAME="dev"
CHECKPOINT="${ROOT}/Data/chatgpt/output/pretrain/${MODEL}"
# Number of few-shot examples (passed as --max_few_shot) and the max input length.
SHOTS=5
MAX_LENGTH=1280

cd "${ROOT}/Code/RLHF" || exit 1
#    cd $ROOT/Code/chatgpt || exit
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# --max_length_generation 1: only a single token is generated per prompt
# (presumably the answer option letter -- confirm against eval_pretrain.py).
#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
#deepspeed --num_gpus 1 $MAIN \
python "${MAIN}" \
  --device_map "auto" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --task "${TASK}" \
  --train_filename "${TRAIN_FILENAME}" \
  --eval_filename "${EVAL_FILENAME}" \
  --checkpoint "${CHECKPOINT}" \
  --max_length "${MAX_LENGTH}" \
  --max_few_shot "${SHOTS}" \
  --max_length_generation 1 \
  > "out/eval_pretrain_${TASK}_${MODEL}_${EVAL_FILENAME}_${SHOTS}-shots_${MAX_LENGTH}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &

## Traditional NLP benchmark Evaluations
#for TASK in  "cluewsc2020" "afqmc" "csl" "iflytek" "ocnli" "cmnli" "tnews" "c3" "cmrc2018" "chid"
#do
#  for MODEL in "pangu-350M" "pangu-2.6B" "glm-350M-chinese" "glm-10B-chinese" "pangu-13B"
#  do
#    ROOT="/mnt/sfevol775196/sunzeye273"
#    #ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#    #ROOT="/mnt/pa002-28359-vol543625-private"
#    #ROOT="/root/autodl-tmp"
#    DATR_DIR=$ROOT/Data/chatgpt/raw/$TASK
##    MAIN=$ROOT/Code/chatgpt/src/eval_pretrain.py
#    MAIN=$ROOT/Code/RLHF/src/eval_pretrain.py
#    MODEL_PATH=$ROOT/Data/models/$MODEL
#    #MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
#    OUTPUT_DIR=$ROOT/Data/chatgpt/output/pretrain/$MODEL
#    EVAL_FILENAME="dev.json"
#    TRAIN_FILENAME="train.json"
#    case $MODEL in
#       "pangu-2.6B")
#          BATCH_SIZE=8
#          ;;
#       *)
#         BATCH_SIZE=32
#         ;;
#    esac
#
#    cd $ROOT/Code/RLHF || exit
##    cd $ROOT/Code/chatgpt || exit
#    mkdir -p $OUTPUT_DIR
#
#    #CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
#    #deepspeed --num_gpus 1 $MAIN \
#    python $MAIN \
#      --device_map auto \
#      --data_dir $DATR_DIR \
#      --output_dir $OUTPUT_DIR \
#      --model_name_or_path $MODEL_PATH \
#      --task $TASK \
#      --max_length 512 \
#      --train_filename $TRAIN_FILENAME \
#      --eval_filename $EVAL_FILENAME \
#      --eval_batch_size $BATCH_SIZE \
#      --top_p 0.8 \
#      --temperature 0.8 \
#      --num_return_sequences 1 \
#      --max_length_generation 100 \
#      > out/eval_pretrain_${MODEL}_${TASK}_"`date "+%Y-%m-%d-%H:%M:%S"`".log 2>&1
#  done
#done

================================================
FILE: examples/eval_reward.sh
================================================
#!/bin/bash
# Evaluate a trained reward model on EVAL_FILENAME via src/train_reward.py --do_eval,
# launched with DeepSpeed (config "stage-3.json"). The job runs in the background;
# stdout/stderr go to a timestamped log under out/.

MODEL="chatglm-6B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_reward.py
MAIN="${ROOT}/Code/RLHF/src/train_reward.py"
TOKENIZER_PATH="${ROOT}/Data/models/${MODEL}"
#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
# The model evaluated is loaded from the reward-model output directory.
MODEL_PATH="${ROOT}/Data/chatgpt/output/reward/${MODEL}"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/reward/${MODEL}"
EVAL_FILENAME="dev_data_external_v1.jsonl"
#CHECKPOINT="${ROOT}/Data/chatgpt/output/reward/${MODEL}/pytorch_modelstar.bin"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#python $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --max_length 512 \
  --logging_steps 10 \
  --deepspeed_config "stage-3.json" \
  --do_eval \
  --eval_filename "${EVAL_FILENAME}" \
  --eval_batch_size 20 \
  > "out/eval_reward_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/pred_pretrain.sh
================================================
#!/bin/bash
# Generate completions for TEST_FILENAME with a pretrained checkpoint via
# src/pretrain.py --do_pred. Results are written to OUTPUT_FILENAME (in OUTPUT_DIR,
# presumably -- confirm in pretrain.py). The job runs in the background; its log
# goes under out/.

MODEL="llama-7B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/pretrain.py
MAIN="${ROOT}/Code/RLHF/src/pretrain.py"
MODEL_PATH="${ROOT}/Data/models/${MODEL}"
#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
# Checkpoint directory to predict with; its weights file is passed as --checkpoint.
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/pretrain/${MODEL}/checkpoint-2000"
CHECKPOINT="${OUTPUT_DIR}/pytorch_model.bin"
TEST_FILENAME="test_prompts.jsonl"
OUTPUT_FILENAME="output_${MODEL}.jsonl"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
python "${MAIN}" \
  --device_map "auto" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --checkpoint "${CHECKPOINT}" \
  --max_length 1024 \
  --max_length_generation 512 \
  --bits 16 \
  --do_pred \
  --test_filename "${TEST_FILENAME}" \
  --output_filename "${OUTPUT_FILENAME}" \
  > "out/pred_pretrain_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/pred_reward.sh
================================================
#!/bin/bash
# Score TEST_FILENAME with a trained reward model via src/train_reward.py --do_pred
# (single process; --device_map "auto"). The base weights come from the SFT output
# and CHECKPOINT points at the reward-model weights. The job runs in the background;
# its log goes under out/.

MODEL="chatglm2-6B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_reward.py
MAIN="${ROOT}/Code/RLHF/src/train_reward.py"
TOKENIZER_PATH="${ROOT}/Data/models/${MODEL}"
#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
MODEL_PATH="${ROOT}/Data/chatgpt/output/sft/${MODEL}"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/reward/${MODEL}"
TEST_FILENAME="test_data_external_v1.jsonl"
OUTPUT_FILENAME="result_external_v1.tsv"
CHECKPOINT="${ROOT}/Data/chatgpt/output/reward/${MODEL}/pytorch_modelstar.bin"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

python "${MAIN}" \
  --local_rank 0 \
  --device_map "auto" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --checkpoint "${CHECKPOINT}" \
  --max_length 512 \
  --logging_steps 50 \
  --do_pred \
  --test_filename "${TEST_FILENAME}" \
  --eval_batch_size 256 \
  --output_filename "${OUTPUT_FILENAME}" \
  > "out/pred_reward_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/pretrain.sh
================================================
#!/bin/bash
# (Continued) pretraining via src/pretrain_wo_trainer.py under DeepSpeed
# (config "stage-3-no_trainer.json"), with periodic evaluation on EVAL_FILENAME.
# The job runs in the background; stdout/stderr go to a timestamped log under out/.

MODEL="llama-7B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/pretrain.py
MAIN="${ROOT}/Code/RLHF/src/pretrain_wo_trainer.py"
MODEL_PATH="${ROOT}/Data/models/${MODEL}"
#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/pretrain/${MODEL}"
TRAIN_FILENAME="pretrain_data_v1.jsonl"
EVAL_FILENAME="pretrain_eval_data_v1.jsonl"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
#python $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --max_length 2048 \
  --logging_steps 50 \
  --save_steps 1000 \
  --metric_for_best_model "eval_rougeLsum" \
  --learning_rate 3e-4 \
  --do_train \
  --train_filename "${TRAIN_FILENAME}" \
  --concat_samples \
  --num_epochs 2 \
  --train_batch_size 8 \
  --gradient_accumulation_steps 8 \
  --warmup_steps 100 \
  --gradient_checkpointing \
  --deepspeed_config "stage-3-no_trainer.json" \
  --do_eval \
  --eval_filename "${EVAL_FILENAME}" \
  --eval_batch_size 8 \
  --eval_steps 500 \
  > "out/pretrain_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/train_dpo.sh
================================================
#!/bin/bash
# DPO training via src/train_dpo.py, in two stages:
#   1. Prediction (single process on cuda:0, foreground): runs --do_pred over
#      TRAIN_FILENAME and EVAL_FILENAME and writes OUTPUT_FILENAME. Skipped when
#      OUTPUT_DIR/OUTPUT_FILENAME already exists. The file name suggests it holds
#      log-probabilities (confirm in train_dpo.py).
#   2. Training (DeepSpeed, background), which is given the same OUTPUT_FILENAME.
# If stage 1 fails, the script aborts instead of starting stage 2 without its input.

MODEL="chatglm2-6B"

#ROOT="/mnt/sfevol775196/sunzeye273"
ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_dpo.py
MAIN="${ROOT}/Code/RLHF/src/train_dpo.py"
#TOKENIZER_PATH=$ROOT/Data/models/$MODEL
TOKENIZER_PATH="/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/${MODEL}"
MODEL_PATH="${ROOT}/Data/chatgpt/output/sft/${MODEL}"
# NOTE(review): defined but not passed to either stage below.
REFERENCE_MODEL_PATH="${ROOT}/Data/chatgpt/output/sft/${MODEL}"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/dpo/${MODEL}"
TRAIN_FILENAME="sft_train_v2.1.jsonl"
EVAL_FILENAME="sft_eval_v1.1.jsonl"
# Stage 1 predicts over both files (comma-separated).
TEST_FILENAME="${TRAIN_FILENAME},${EVAL_FILENAME}"
OUTPUT_FILENAME="dpo_logps_v2.1.bin"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirects below fail (and the jobs never start) if out/ does not exist.
mkdir -p out

# ---- Stage 1: prediction ----
if [ -f "${OUTPUT_DIR}/${OUTPUT_FILENAME}" ]
then
    echo "${OUTPUT_DIR}/${OUTPUT_FILENAME} already exists, skipping prediction stage"
else
    PRED_LOG="out/pred_dpo_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log"
    python "${MAIN}" \
      --local_rank 0 \
      --device_map "cuda:0" \
      --data_dir "${DATA_DIR}" \
      --output_dir "${OUTPUT_DIR}" \
      --tokenizer_path "${TOKENIZER_PATH}" \
      --model_name_or_path "${MODEL_PATH}" \
      --max_length 512 \
      --logging_steps 10 \
      --eval_batch_size 32 \
      --do_pred \
      --test_filename "${TEST_FILENAME}" \
      --output_filename "${OUTPUT_FILENAME}" \
      > "${PRED_LOG}" 2>&1 \
      || { echo "prediction stage failed, see ${PRED_LOG}" >&2; exit 1; }
fi

# ---- Stage 2: training ----
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
#python $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --model_name_or_path "${MODEL_PATH}" \
  --max_length 512 \
  --logging_steps 10 \
  --save_steps 100 \
  --learning_rate 1e-5 \
  --do_train \
  --train_filename "${TRAIN_FILENAME}" \
  --train_batch_size 8 \
  --gradient_accumulation_steps 8 \
  --num_epochs 5 \
  --gradient_checkpointing \
  --deepspeed_config "stage-3.json" \
  --do_eval \
  --eval_filename "${EVAL_FILENAME}" \
  --eval_batch_size 32 \
  --output_filename "${OUTPUT_FILENAME}" \
  > "out/train_dpo_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/train_reward.sh
================================================
#!/bin/bash
# Train a reward model starting from the SFT output via src/train_reward.py under
# DeepSpeed (config "stage-3.json"), evaluating on EVAL_FILENAME and selecting the
# best model by "eval_accuracy". The job runs in the background; its log goes under out/.

MODEL="chatglm-6B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_reward.py
MAIN="${ROOT}/Code/RLHF/src/train_reward.py"
TOKENIZER_PATH="${ROOT}/Data/models/${MODEL}"
#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL
MODEL_PATH="${ROOT}/Data/chatgpt/output/sft/${MODEL}"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/reward/${MODEL}"
TRAIN_FILENAME="train_data_external_v1.jsonl"
EVAL_FILENAME="dev_data_external_v1.jsonl"
#CHECKPOINT="${ROOT}/Data/chatgpt/output/sft/${MODEL}/pytorch_modelstar.bin"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#python $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --model_name_or_path "${MODEL_PATH}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --max_length 512 \
  --logging_steps 10 \
  --save_steps 100 \
  --metric_for_best_model "eval_accuracy" \
  --learning_rate 1e-5 \
  --do_train \
  --train_filename "${TRAIN_FILENAME}" \
  --train_batch_size 24 \
  --gradient_accumulation_steps 8 \
  --num_epochs 5 \
  --gradient_checkpointing \
  --deepspeed_config "stage-3.json" \
  --do_eval \
  --eval_filename "${EVAL_FILENAME}" \
  --eval_batch_size 24 \
  > "out/train_reward_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/train_rlhf.sh
================================================
#!/bin/bash
# RLHF (PPO) training via src/train_rlhf.py under DeepSpeed. The actor is
# initialised from the SFT output of SFT_MODEL. The critic is built from
# REWARD_MODEL's base weights plus the trained reward checkpoint (CRITIC_CHECKPOINT).
# The job runs in the background; its log goes under out/.

REWARD_MODEL="pangu-350M"
SFT_MODEL="pangu-2.6B"

ROOT="/mnt/sfevol775196/sunzeye273"
#ROOT="/mnt/share-pa002-vol682688-prd/sunzeye273"
#ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp/"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_rlhf.py
MAIN="${ROOT}/Code/RLHF/src/train_rlhf.py"
#ACCELERATE_CONFIG=$ROOT/Code/RLHF/src/resources/ppo_model/default_accelerate_config.yaml
TOKENIZER_PATH="${ROOT}/Data/models/${REWARD_MODEL}"
ACTOR_MODEL_PATH="${ROOT}/Data/chatgpt/output/sft/${SFT_MODEL}"
CRITIC_MODEL_PATH="${ROOT}/Data/models/${REWARD_MODEL}"
#CRITIC_MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$REWARD_MODEL
CRITIC_CHECKPOINT="${ROOT}/Data/chatgpt/output/reward/${REWARD_MODEL}/pytorch_model.bin"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/rlhf/${SFT_MODEL}"
TRAIN_FILENAME="train_data_external_v1.jsonl"
PRETRAIN_FILENAME="train_data_external_v1.jsonl"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#python $MAIN \
#accelerate launch --main_process_port 5007 --config_file $ACCELERATE_CONFIG $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --actor_model_path "${ACTOR_MODEL_PATH}" \
  --critic_model_path "${CRITIC_MODEL_PATH}" \
  --critic_checkpoint "${CRITIC_CHECKPOINT}" \
  --max_length 512 \
  --max_gen_length 256 \
  --logging_steps 10 \
  --do_train \
  --train_filename "${TRAIN_FILENAME}" \
  --pretrain_filename "${PRETRAIN_FILENAME}" \
  --actor_learning_rate 1e-5 \
  --critic_learning_rate 1e-5 \
  --lr_scheduler_type cosine \
  --train_batch_size 4 \
  --ppo_train_batch_size 4 \
  --gradient_accumulation_steps 16 \
  --num_epochs 1 \
  --ppo_epochs 1 \
  --enable_hybrid_engine \
  --actor_zero_stage 3 \
  --critic_zero_stage 3 \
  --offload_reference_model \
  --actor_gradient_checkpointing \
  --critic_gradient_checkpointing \
  --release_inference_cache \
  > "out/train_rlhf_${SFT_MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: examples/train_sft.sh
================================================
#!/bin/bash
# Supervised fine-tuning (SFT) via src/train_sft.py under DeepSpeed
# (config "stage-3.json"), evaluating on EVAL_FILENAME and selecting the best
# model by "eval_rougeLsum". The job runs in the background; its log goes under out/.

MODEL="chatglm2-6B"

#ROOT="/mnt/sfevol775196/sunzeye273"
ROOT="/mnt/pa002-28359-vol543625-private"
#ROOT="/root/autodl-tmp"
DATA_DIR="${ROOT}/Data/chatgpt/processed"
#MAIN=$ROOT/Code/chatgpt/src/train_sft.py
MAIN="${ROOT}/Code/RLHF/src/train_sft.py"
#TOKENIZER_PATH=$ROOT/Data/models/$MODEL
TOKENIZER_PATH="/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/${MODEL}"
#MODEL_PATH=$ROOT/Data/models/$MODEL
MODEL_PATH="/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/${MODEL}"
OUTPUT_DIR="${ROOT}/Data/chatgpt/output/sft/${MODEL}"
TRAIN_FILENAME="train_data_external_v1.jsonl"
EVAL_FILENAME="dev_data_external_v1.jsonl"

#cd $ROOT/Code/chatgpt || exit
cd "${ROOT}/Code/RLHF" || exit 1
mkdir -p "${OUTPUT_DIR}"
# The log redirect below fails (and the job never starts) if out/ does not exist.
mkdir -p out

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (precise error
# locations while debugging, at the cost of throughput).
#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \
#python $MAIN \
CUDA_LAUNCH_BLOCKING=1 deepspeed "${MAIN}" \
  --data_dir "${DATA_DIR}" \
  --output_dir "${OUTPUT_DIR}" \
  --tokenizer_path "${TOKENIZER_PATH}" \
  --model_name_or_path "${MODEL_PATH}" \
  --max_length 512 \
  --logging_steps 10 \
  --save_steps 100 \
  --metric_for_best_model "eval_rougeLsum" \
  --learning_rate 1e-5 \
  --do_train \
  --train_filename "${TRAIN_FILENAME}" \
  --concat_samples \
  --train_batch_size 8 \
  --gradient_accumulation_steps 8 \
  --num_epochs 5 \
  --gradient_checkpointing \
  --deepspeed_config "stage-3.json" \
  --do_eval \
  --eval_filename "${EVAL_FILENAME}" \
  --eval_batch_size 32 \
  > "out/train_sft_${MODEL}_$(date "+%Y-%m-%d-%H:%M:%S").log" 2>&1 &


================================================
FILE: notebook/data_prepare.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, time, re, random, glob, json, jieba, copy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForMultipleChoice,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    TextGenerationPipeline\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "from sys import platform\n",
    "if platform == \"linux\" or platform == \"linux2\":\n",
    "    # linux\n",
    "    root = \"/mnt/sfevol775196/sunzeye273/Data\"\n",
    "#     root = \"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\"\n",
    "#     root = \"/mnt/pa002-28359-vol543625-private/Data\"\n",
    "#     root = \"/root/autodl-tmp/Data\"\n",
    "elif platform == \"darwin\":\n",
    "    # OS X\n",
    "    root = \"/Users/zeyesun/Documents/Data\"\n",
    "elif platform == \"win32\":\n",
    "    # Windows...\n",
    "    root = \"D:\\\\Data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CLEAN_TEXT_PATTERN = re.compile(r\"[\\r\\n]\")\n",
    "\n",
    "def clean_text(text):\n",
    "    return CLEAN_TEXT_PATTERN.sub(\"\", text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"pangu-350M\"\n",
    "# model_name = \"chatglm-6B\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "#     f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "#     f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### weibo_summary_comments_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "fi = os.path.join(root, \"raw\", \"weibo_summary_comments_json.json\")\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"weibo_summary_comments.jsonl\")\n",
    "ct = 0\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    with open(fi, \"r\", encoding=\"utf-8\") as r:\n",
    "        for line in r:\n",
    "            line = line.strip(\"\\n\")\n",
    "            # skip blank lines (e.g. a trailing newline), which json.loads cannot parse\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            item = json.loads(line)\n",
    "            article = item['article'].replace(\" \", \"\")\n",
    "            abstract = item['abstract'].replace(\" \", \"\")\n",
    "            prompt = f\"新闻内容:{article}{tokenizer.sep_token}摘要:{abstract}{tokenizer.sep_token}评论:\"\n",
    "            # comments are (text, score) pairs: highest score first, longer text breaks ties\n",
    "            answers = [\n",
    "                {\n",
    "                    \"answer\": k.replace(\" \", \"\"), \n",
    "                    \"score\": int(v)\n",
    "                } for (k, v) in sorted(item['comments'], key=lambda x: (int(x[1]), len(x[0])), reverse=True)\n",
    "            ]\n",
    "            w.write(json.dumps({\"prompt\": prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "            ct += 1\n",
    "print(f\"length: {ct}, time taken: {time.time()-t} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### couplets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = time.time()\n",
    "fi = os.path.join(root, \"raw\", \"couplets.txt\")\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"couplets.jsonl\")\n",
    "l2 = []\n",
    "nexts = dict()\n",
    "with open(fi, \"r\", encoding=\"utf-8\") as r:\n",
    "    for line in r:\n",
    "        line = line.strip(\"\\n\")\n",
    "        if not line:\n",
    "            continue\n",
    "        # first half is the upper line; one separator character (presumably a space) is skipped\n",
    "        idx = len(line) // 2\n",
    "        prompt = line[:idx]\n",
    "        answer = line[idx+1:]\n",
    "        answers = [{\"answer\": answer, \"score\": 1}]\n",
    "        l2.append({\"prompt\": f\"上联:{prompt}{tokenizer.sep_token}下联:\", \"answers\": answers})\n",
    "        nexts.setdefault(len(answer), list()).append(answer)\n",
    "t2 = time.time()\n",
    "print(f\"length: {len(l2)}, # different lengths: {len(nexts)}, time taken: {t2-t1} s\")\n",
    "# de-duplicate each length bucket once instead of building a set for every sample;\n",
    "# random.sample also needs a sequence (sets are rejected since Python 3.11)\n",
    "uniq_nexts = {k: list(set(v)) for k, v in nexts.items()}\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    for l in tqdm(l2, desc=\"Processing Couplets\"):\n",
    "        # copy this sample's own answers; never extend a list left over from another sample\n",
    "        answers = list(l['answers'])\n",
    "        answer = answers[0]['answer']\n",
    "        length = len(answer)\n",
    "        # negatives with the same length as the true lower line (score 0)\n",
    "        same_len = uniq_nexts[length]\n",
    "        picked = random.sample(same_len, min(3, len(same_len)))\n",
    "        answers.extend([{\"answer\": fa, \"score\": 0} for fa in picked if fa != answer][:2])\n",
    "        # negatives with a different length (score -1)\n",
    "        other_lengths = [k for k in uniq_nexts if k != length]\n",
    "        picked_lengths = random.sample(other_lengths, min(2, len(other_lengths)))\n",
    "        answers.extend([{\"answer\": random.choice(uniq_nexts[key]), \"score\": -1} for key in picked_lengths])\n",
    "        w.write(json.dumps({\"prompt\": l['prompt'], \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "print(f\"length: {len(l2)}, time taken: {time.time()-t2} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### zhidao"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "fp = os.path.join(root, \"raw\", \"zhidao\", \"*.csv\")\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"zhidao.jsonl\")\n",
    "# rows processed over all csv files (not reset per file)\n",
    "ct = 0\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    for fi in glob.glob(fp):\n",
    "        # sorting by title puts all replies to one question next to each other, highest is_best first\n",
    "        df = pd.read_csv(fi).sort_values(by=[\"title\", \"is_best\"], ascending=False)\n",
    "        prev_title = None\n",
    "        prev_prompt = None\n",
    "        answers = []\n",
    "        for _, val in df.iterrows():\n",
    "            if isinstance(val['question'], str) and val['question'] != val['title']:\n",
    "                prompt = f\"问题:{val['title']}{tokenizer.sep_token}内容:{val['question']}{tokenizer.sep_token}回答:\"\n",
    "            else:\n",
    "                prompt = f\"问题:{val['title']}{tokenizer.sep_token}回答:\"\n",
    "            if prev_title is not None and prev_title == val['title']:\n",
    "                answers.append({\"answer\": val['reply'], \"score\": val['is_best']})\n",
    "            else:\n",
    "                # a new question starts: flush the previous one\n",
    "                if prev_title is not None:\n",
    "                    w.write(json.dumps({\"prompt\": prev_prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "                answers = [{\"answer\": val['reply'], \"score\": val['is_best']}]\n",
    "            prev_prompt = prompt\n",
    "            prev_title = val['title']\n",
    "            ct += 1\n",
    "        # flush the last question of the file; an empty file has nothing to write\n",
    "        if prev_title is not None:\n",
    "            w.write(json.dumps({\"prompt\": prev_prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "        print(f\"finished processing {os.path.basename(fi)}\")\n",
    "print(f\"length: {ct}, time taken: {time.time()-t} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JDData"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from html.parser import HTMLParser\n",
    "class MyHTMLParser(HTMLParser):\n",
    "    \"\"\"Records start tags, end tags, self-closing tags and text chunks seen while parsing.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Initializing lists\n",
    "        self.start_tags = list()\n",
    "        self.end_tags = list()\n",
    "        self.start_end_tags = list()\n",
    "        self.data_list = list()\n",
    "    # HTML Parser Methods\n",
    "    def handle_starttag(self, startTag, attrs):\n",
    "        self.start_tags.append(startTag)\n",
    "    def handle_endtag(self, endTag):\n",
    "        self.end_tags.append(endTag)\n",
    "    def handle_startendtag(self, startendTag, attrs):\n",
    "        self.start_end_tags.append(startendTag)\n",
    "    def handle_data(self, data):\n",
    "        self.data_list.append(data)\n",
    "        \n",
    "t = time.time()\n",
    "fp_jd = os.path.join(root, \"raw\", \"JDData\", \"*.data*\")\n",
    "# Nothing is written yet, so no output file is opened here: reusing the `fo` left over from an\n",
    "# earlier cell (e.g. zhidao.jsonl) would truncate that file.\n",
    "ct = 0\n",
    "for fi in glob.glob(fp_jd):\n",
    "    with open(fi, \"r\", encoding=\"gbk\") as r:\n",
    "        # only the first record of each file is inspected\n",
    "        line = r.readline()\n",
    "    items = line.strip(\"\\n\").split(\"\\t\")\n",
    "    if len(items) > 1:\n",
    "        parser = MyHTMLParser()\n",
    "        parser.feed(items[1])\n",
    "        # `tag` rather than `t`, so the timer started above is not overwritten\n",
    "        for tag, d in zip(parser.start_tags, parser.data_list):\n",
    "            print(f\"{tag}: {d}\")\n",
    "    ct += 1\n",
    "    print(f\"finished processing {os.path.basename(fi)}\")\n",
    "print(f\"length: {ct}, time taken: {time.time()-t} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### yf_amazon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"products.csv\")\n",
    "dfp = pd.read_csv(fi)\n",
    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"ratings.csv\")\n",
    "dfr = pd.read_csv(fi)\n",
    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"categories.csv\")\n",
    "dfc = pd.read_csv(fi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfp.columns\n",
    "# dfp['name'].unique().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dfp['cate_id_1'] = dfp['catIds'].apply(lambda x: x.split(\",\")[0])\n",
    "for cid1 in dfp['cate_id_1'].unique():\n",
    "    print(dfc[dfc['catId']==int(cid1)]['category'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### dmsc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "fi = os.path.join(root, \"raw\", \"dmsc\", \"movies.csv\")\n",
    "dfm = pd.read_csv(fi)\n",
    "print(dfm.shape)\n",
    "fi = os.path.join(root, \"raw\", \"dmsc\", \"ratings.csv\")\n",
    "dfr = pd.read_csv(fi)\n",
    "print(dfr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of non-null comments per (movieId, rating); group keys must be passed as one list,\n",
    "# a second positional argument to groupby is not a key\n",
    "dfr.groupby([\"movieId\", \"rating\"])[\"comment\"].count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chinese Classical-Modern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = time.time()\n",
    "fp = os.path.join(root, \"raw\", \"Classical-Modern\", \"bitext\", \"*\")\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"chinese_classical.jsonl\")\n",
    "l3 = []\n",
    "dicts = dict()\n",
    "\n",
    "def make_classical_sample(p1, p2, name):\n",
    "    \"\"\"Build one translation sample from a classical/modern pair, in a random direction.\"\"\"\n",
    "    pair = [(\"古文\", p1), (\"现代文\", p2)]\n",
    "    random.shuffle(pair)\n",
    "    prompt = f\"{pair[0][0]}:{pair[0][1]}{tokenizer.sep_token}{pair[1][0]}:\"\n",
    "    answers = [{\"answer\": pair[1][1], \"score\": 1}]\n",
    "    return {\"prompt\": prompt, \"answers\": answers, \"name\": name}\n",
    "\n",
    "for fi in glob.glob(fp):\n",
    "    name = os.path.basename(fi)\n",
    "    dicts[name] = {\"古文\": [], \"现代文\": []}\n",
    "    # reset per file: an unfinished pair must not leak into the next file\n",
    "    p1 = None\n",
    "    p2 = None\n",
    "    with open(fi, \"r\", encoding=\"utf-8\") as r:\n",
    "        for line in r:\n",
    "            line = line.strip(\"\\n\")\n",
    "            if line.startswith(\"古文\"):\n",
    "                p1 = line[3:]\n",
    "                dicts[name]['古文'].append(p1)\n",
    "            elif line.startswith(\"现代文\"):\n",
    "                p2 = line[4:]\n",
    "                dicts[name]['现代文'].append(p2)\n",
    "            elif p1 is not None and p2 is not None:\n",
    "                l3.append(make_classical_sample(p1, p2, name))\n",
    "                p1 = None\n",
    "                p2 = None\n",
    "    # the last pair of a file is not necessarily followed by a separator line\n",
    "    if p1 is not None and p2 is not None:\n",
    "        l3.append(make_classical_sample(p1, p2, name))\n",
    "t2 = time.time()\n",
    "print(f\"length: {len(l3)}, # different names: {len(dicts)}, time taken: {t2-t1} s\")\n",
    "# de-duplicated candidate pools built once (random.sample needs a sequence, not a set, since Python 3.11)\n",
    "pools = {n: {k: list(set(v)) for k, v in d.items()} for n, d in dicts.items()}\n",
    "names_by_type = {k: [n for n in pools if pools[n][k]] for k in (\"古文\", \"现代文\")}\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    for l in tqdm(l3, desc=\"Processing Chinese Classical-Modern\"):\n",
    "        name = l['name']\n",
    "        prompt = l['prompt']\n",
    "        # copy this sample's own answers; never extend a list left over from another sample\n",
    "        answers = list(l['answers'])\n",
    "        answer = answers[0]['answer']\n",
    "        answer_type = '现代文' if prompt.startswith(\"古文\") else '古文'\n",
    "        # negatives from the same file (score 0)\n",
    "        same_file = pools[name][answer_type]\n",
    "        picked = random.sample(same_file, min(3, len(same_file)))\n",
    "        answers.extend([{\"answer\": fa, \"score\": 0} for fa in picked if fa != answer][:2])\n",
    "        # negatives from up to two other files (score -1)\n",
    "        candidates = names_by_type[answer_type]\n",
    "        others = [n for n in random.sample(candidates, min(3, len(candidates))) if n != name][:2]\n",
    "        answers.extend([{\"answer\": random.choice(pools[key][answer_type]), \"score\": -1} for key in others])\n",
    "        w.write(json.dumps({\"prompt\": prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "print(f\"length: {len(l3)}, time taken: {time.time()-t2} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chinese Poetry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import opencc\n",
    "converter = opencc.OpenCC('t2s.json')\n",
    "t1 = time.time()\n",
    "poetry_root = os.path.join(root, \"raw\", \"chinese-poetry\")\n",
    "fp = [\n",
    "    # Four Books and Five Classics\n",
    "    os.path.join(poetry_root, \"lunyu\", \"lunyu.json\"),\n",
    "#     os.path.join(poetry_root, \"mengxue\", \"*.json\"),\n",
    "    os.path.join(poetry_root, \"sishuwujing\", \"*.json\"),\n",
    "    # ancient-style poetry\n",
    "    os.path.join(poetry_root, \"caocaoshiji\", \"caocao.json\"),\n",
    "    os.path.join(poetry_root, \"shijing\", \"shijing.json\"),\n",
    "    # Chu Ci\n",
    "    os.path.join(poetry_root, \"chuci\", \"chuci.json\"),\n",
    "    # shi poems\n",
    "    os.path.join(poetry_root, \"shi\", \"poet*.json\"),\n",
    "    # ci lyrics\n",
    "    os.path.join(poetry_root, \"ci\", \"ci*.json\"),\n",
    "    os.path.join(poetry_root, \"nalanxingde\", \"*.json\"),\n",
    "    os.path.join(poetry_root, \"wudai\", \"huajianji\", \"*juan.json\"),\n",
    "    os.path.join(poetry_root, \"wudai\", \"nantang\", \"poetrys.json\"),\n",
    "    # qu\n",
    "    os.path.join(poetry_root, \"yuanqu\", \"yuanqu.json\"),\n",
    "]\n",
    "fs = [each for f in fp for each in glob.glob(f)]\n",
    "\n",
    "l5 = []\n",
    "dicts = dict()\n",
    "for fi in fs:\n",
    "    # dispatch on the path below poetry_root so directory names in `root` cannot match a branch\n",
    "    rel = os.path.relpath(fi, poetry_root)\n",
    "    with open(fi, \"r\", encoding=\"utf-8\") as r:\n",
    "        lines = json.load(r)\n",
    "    if isinstance(lines, dict):\n",
    "        lines = [lines]\n",
    "    for line in lines:\n",
    "        if \"lunyu\" in rel:\n",
    "            author = \"孔子\"\n",
    "            genre = \"经书\"\n",
    "            title = line['chapter']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        elif \"daxue\" in rel:\n",
    "            author = \"曾子\"\n",
    "            genre = \"经书\"\n",
    "            title = \"大学\"\n",
    "            contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n",
    "        elif \"mengzi\" in rel:\n",
    "            author = \"孟子\"\n",
    "            genre = \"经书\"\n",
    "            title = converter.convert(line['chapter'])\n",
    "            contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n",
    "        elif \"zhongyong\" in rel:\n",
    "            author = \"孔伋\"\n",
    "            genre = \"经书\"\n",
    "            title = \"中庸\"\n",
    "            contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n",
    "        elif \"caocao\" in rel:\n",
    "            author = \"曹操\"\n",
    "            genre = \"古体诗\"\n",
    "            title = line['title']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        elif \"shijing\" in rel:\n",
    "            author = \"诗经\"\n",
    "            genre = \"古体诗\"\n",
    "            title = line['chapter'] + \"-\" + line['section'] + \"-\" + line['title']\n",
    "            contents = \"\".join(line['content'])\n",
    "        elif \"chuci\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"楚辞\"\n",
    "            title = line['section'] + \"-\" + line['title']\n",
    "            contents = \"\".join(line['content'])\n",
    "        elif \"nalanxingde\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"词\"\n",
    "            title = line['title']\n",
    "            contents = \"\".join(line['para'])\n",
    "        elif \"huajianji\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"词\"\n",
    "            title = line['title']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        elif \"nantang\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"词\"\n",
    "            title = line['title']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        elif \"yuanqu\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"曲\"\n",
    "            title = line['title']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        elif \"shi\" in rel:\n",
    "            if len(line['paragraphs']) <= 0:\n",
    "                continue\n",
    "            author = converter.convert(line['author'])\n",
    "            genre = \"五言诗\" if len(line['paragraphs'][0]) == 12 else \"七言诗\"\n",
    "            title = converter.convert(line['title'])\n",
    "            contents = converter.convert(\"\".join(line['paragraphs']))\n",
    "        elif \"ci\" in rel:\n",
    "            author = line['author']\n",
    "            genre = \"词\"\n",
    "            title = line['rhythmic']\n",
    "            contents = \"\".join(line['paragraphs'])\n",
    "        else:\n",
    "            # unrecognised source: skip rather than reuse the previous record's fields\n",
    "            continue\n",
    "        if genre not in dicts:\n",
    "            dicts[genre] = dict()\n",
    "        if author not in dicts[genre]:\n",
    "            dicts[genre][author] = dict()\n",
    "        quantifier = \"篇\" if genre in [\"经书\", \"楚辞\"] else \"首\"\n",
    "        prompt = f\"以{author}的风格,写一{quantifier}{genre},题为{title}{tokenizer.sep_token}\"\n",
    "        answers = [{\"answer\": contents, \"score\": 1}]\n",
    "        l5.append({\"prompt\": prompt, \"answers\": answers, \"genre\": genre, \"title\": title, \"author\": author})\n",
    "        dicts[genre][author][title] = contents\n",
    "        \n",
    "t2 = time.time()\n",
    "print(f\"length: {len(l5)}, # different genres: {len(dicts)}, time taken: {t2-t1} s\")\n",
    "\n",
    "def pick_other(pool, exclude):\n",
    "    \"\"\"Return a random element of `pool` other than `exclude`, or None if there is none.\"\"\"\n",
    "    # pool elements are distinct, so of two sampled elements at least one differs from exclude\n",
    "    for cand in random.sample(pool, min(2, len(pool))):\n",
    "        if cand != exclude:\n",
    "            return cand\n",
    "    return None\n",
    "\n",
    "# list views built once, instead of rebuilding key sets for every sample\n",
    "genre_list = list(dicts)\n",
    "author_lists = {g: list(dicts[g]) for g in dicts}\n",
    "title_lists = {g: {a: list(dicts[g][a]) for a in dicts[g]} for g in dicts}\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"chinese_poetry.jsonl\")\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    for l in tqdm(l5, desc=\"Processing Chinese Poetry\"):\n",
    "        genre = l['genre']\n",
    "        author = l['author']\n",
    "        title = l['title']\n",
    "        prompt = l['prompt']\n",
    "        answers = list(l['answers'])\n",
    "        # another work by the same author (score 0)\n",
    "        tt = pick_other(title_lists[genre][author], title)\n",
    "        if tt is not None:\n",
    "            answers.append({\"answer\": dicts[genre][author][tt], \"score\": 0})\n",
    "        # a work by another author of the same genre (score -1)\n",
    "        a = pick_other(author_lists[genre], author)\n",
    "        if a is not None:\n",
    "            tt = random.choice(title_lists[genre][a])\n",
    "            answers.append({\"answer\": dicts[genre][a][tt], \"score\": -1})\n",
    "        # a work of a different genre (score -2)\n",
    "        g = pick_other(genre_list, genre)\n",
    "        if g is not None:\n",
    "            a = random.choice(author_lists[g])\n",
    "            tt = random.choice(title_lists[g][a])\n",
    "            answers.append({\"answer\": dicts[g][a][tt], \"score\": -2})\n",
    "        w.write(json.dumps({\"prompt\": prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "print(f\"length: {len(l5)}, time taken: {time.time()-t2} s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### baike_qa_2019"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = glob.glob(os.path.join(root, \"raw\", \"baike_qa2019\", \"baike_qa_*.json\"))\n",
    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"baike_qa.jsonl\")\n",
    "ct = 0\n",
    "# items = []\n",
    "# lens_prompt = []\n",
    "# lens_label = []\n",
    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",
    "    for f in fs:\n",
    "        with open(f, \"r\", encoding=\"utf-8\") as r:\n",
    "            while True:\n",
    "                line = r.readline()\n",
    "                if not line:\n",
    "                    break\n",
    "                item = json.loads(line.strip(\"\\n\"))\n",
    "                question = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])\n",
    "                prompt = f\"{question}{tokenizer.sep_token}回答:\"\n",
    "                answer = clean_text(item['answer'])\n",
    "                answers = [{\"answer\": answer, \"score\": 1}]\n",
    "#                 items.append(item)\n",
    "#                 lens_prompt.append(len(prompt))\n",
    "#                 lens_label.append(len(label))\n",
    "                w.write(json.dumps({\"prompt\": prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",
    "                ct += 1\n",
    "print(ct)\n",
    "# print(len(items))\n",
    "# print(np.percentile(lens_prompt, np.arange(90, 101)))\n",
    "# print(np.percentile(lens_label, np.arange(90, 101)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### rm-static"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fi = os.path.join(root, \"raw\", \"rm-static\", \"data\", \"test-00000-of-00001-bf4c733542e35fcb.parquet\")\n",
    "df = pd.read_parquet(fi)\n",
    "print(df.shape)\n",
    "df.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatGLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\n",
    "    \"你好\",\n",
    "    \"晚上睡不着应该怎么办\"\n",
    "]\n",
    "history = []\n",
    "for text in texts:\n",
    "    response, history = model.chat(tokenizer, text, history=history)\n",
    "    print(f\"问: {text}\\n答:{response}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: notebook/reward.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d1682d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\n",
    "sys.path.insert(0, \"/Users/zeyesun/Documents/Code/RLHF\")\n",
    "sys.path.insert(0, \"D:\\\\Code\\\\RLHF\")\n",
    "sys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\n",
    "\n",
    "import os, time, re, random, glob, json, jieba, copy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForMultipleChoice,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    TextGenerationPipeline\n",
    ")\n",
    "\n",
    "from src.models.reward import RewardModel\n",
    "\n",
    "device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "from sys import platform\n",
    "if platform == \"linux\" or platform == \"linux2\":\n",
    "    # linux\n",
    "    root = \"/mnt/sfevol775196/sunzeye273/Data\"\n",
    "#     root = \"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\"\n",
    "#     root = \"/mnt/pa002-28359-vol543625-private/Data\"\n",
    "#     root = \"/root/autodl-tmp/Data\"\n",
    "elif platform == \"darwin\":\n",
    "    # OS X\n",
    "    root = \"/Users/zeyesun/Documents/Data\"\n",
    "elif platform == \"win32\":\n",
    "    # Windows...\n",
    "    root = \"D:\\\\Data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67e7e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig\n",
    "model_name = \"pangu-small\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\n",
    "f = os.path.join(model_name_or_path, \"pytorch_model.bin\")\n",
    "torch.save(model.state_dict(), f)\n",
    "\n",
    "model.config.lora_rank = 0\n",
    "reward_model = RewardModel(model.config, model.transformer, tokenizer)\n",
    "reward_checkpoint = os.path.join(root, \"chatgpt\", \"output\", \"reward\", model_name, \"pytorch_model.bin\")\n",
    "torch.save(reward_model.state_dict(), reward_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078a3631",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_name = \"pangu-350M\"\n",
    "model_name = \"pangu-small\"\n",
    "# model_name = \"chatglm-6B\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "#     f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "#     f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "417b6ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_rank = 0\n",
    "lora_alpha = 1\n",
    "lora_train_bias = \"none\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "753d0550",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"pangu\" in model_name_or_path:\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",
    "    model.config.lora_rank = lora_rank\n",
    "    model.config.lora_alpha = lora_alpha\n",
    "    model.config.lora_train_bias = lora_train_bias\n",
    "    # Initialize the reward model from the (supervised) fine-tuned SFT model\n",
    "    reward_model = RewardModel(model.config, model.transformer, tokenizer)\n",
    "elif \"chatglm\" in model_name_or_path:\n",
    "    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True).half()\n",
    "    model.config.lora_rank = lora_rank\n",
    "    model.config.lora_alpha = lora_alpha\n",
    "    model.config.lora_train_bias = lora_train_bias\n",
    "    # Initialize the reward model from the (supervised) fine-tuned SFT model\n",
    "    reward_model = RewardModel(model.config, model.transformer, tokenizer).half()\n",
    "    # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\n",
    "#     layers = reward_model.transformer.layers\n",
    "elif \"glm\" in model_name_or_path:\n",
    "    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "    model.config.lora_rank = lora_rank\n",
    "    model.config.lora_alpha = lora_alpha\n",
    "    model.config.lora_train_bias = lora_train_bias\n",
    "    # Initialize the reward model from the (supervised) fine-tuned SFT model\n",
    "    reward_model = RewardModel(model.config, model.glm, tokenizer)\n",
    "    # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\n",
    "#     layers = reward_model.transformer.transformer.layers\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported model name: {model_name_or_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b18dc5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reward_checkpoint = os.path.join(root, \"chatgpt\", \"output\", \"reward\", model_name, \"checkpoint-200549\", \"pytorch_model*.bin\")\n",
    "reward_checkpoint = os.path.join(root, \"chatgpt\", \"output\", \"reward\", model_name, \"pytorch_model*.bin\")\n",
    "checkpoints = glob.glob(reward_checkpoint)\n",
    "st = dict()\n",
    "for checkpoint in checkpoints:\n",
    "    st.update(torch.load(checkpoint, map_location=\"cpu\"))\n",
    "reward_model.load_state_dict(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800f59c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reward_model = reward_model.half().to(device)\n",
    "reward_model = reward_model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a68f86ad",
   "metadata": {},
   "source": [
    "# Reward Model Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58669674",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.data import SFTDataset, PairwiseDataset\n",
    "from torch.utils.data import DataLoader, SequentialSampler\n",
    "class dotdict(dict):\n",
    "    __getattr__ = dict.get\n",
    "    __setattr__ = dict.__setitem__\n",
    "    __delattr__ = dict.__delitem__\n",
    "\n",
    "max_length = 128\n",
    "batch_size = 2\n",
    "args = dotdict({\"model_name_or_path\": model_name_or_path, \"max_length\": max_length})\n",
    "train_filename = os.path.join(root, \"chatgpt\", \"processed\", \"dev_data_external_v1.jsonl\")\n",
    "train_dataset = PairwiseDataset(args, train_filename, tokenizer)\n",
    "sampler = SequentialSampler(train_dataset)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdab3a67",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "reward_model.train()\n",
    "for batch in train_loader:\n",
    "    print(batch.keys())\n",
    "    chosen_input_ids = batch['chosen_input_ids'].to(device)\n",
    "    chosen_attention_mask = batch['chosen_attention_mask'].to(device) if 'chosen_attention_mask' in batch else None\n",
    "    chosen_position_ids = batch['chosen_position_ids'].to(device) if 'chosen_position_ids' in batch else None\n",
    "    rejected_input_ids = batch['rejected_input_ids'].to(device)\n",
    "    rejected_attention_mask = batch['rejected_attention_mask'].to(device) if 'rejected_attention_mask' in batch else None\n",
    "    rejected_position_ids = batch['rejected_position_ids'].to(device) if 'rejected_position_ids' in batch else None\n",
    "    \n",
    "    chosen_values, chosen_reward = reward_model.reward(chosen_input_ids, chosen_attention_mask, chosen_position_ids)\n",
    "    reject_values, reject_reward = reward_model.reward(rejected_input_ids, rejected_attention_mask, rejected_position_ids)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44da49d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "chosen_end_scores = []\n",
    "rejected_end_scores = []\n",
    "bs = chosen_input_ids.shape[0]\n",
    "if len(chosen_input_ids.shape) == 3:\n",
    "    chosen_input_ids = chosen_input_ids.squeeze(1)\n",
    "if len(rejected_input_ids.shape) == 3:\n",
    "    rejected_input_ids = rejected_input_ids.squeeze(1)\n",
    "\n",
    "loss = 0\n",
    "inference = False\n",
    "for i in range(bs):\n",
    "    if torch.all(torch.eq(chosen_input_ids[i], rejected_input_ids[i])).item():\n",
    "        c_inds = (chosen_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",
    "        c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n",
    "        chosen_end_scores.append(chosen_values[i, c_ind - 1])\n",
    "        inference = True\n",
    "        continue\n",
    "    print(f\"inference: {inference}\")\n",
    "\n",
    "    # Check if there is any padding otherwise take length of sequence\n",
    "    c_inds = (chosen_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",
    "    c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n",
    "    r_inds = (rejected_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",
    "    r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected_input_ids.shape[1]\n",
    "    end_ind = max(c_ind, r_ind)\n",
    "\n",
    "    # Retrieve first index where trajectories diverge\n",
    "    divergence_ind = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()[0]\n",
    "    assert divergence_ind > 0\n",
    "\n",
    "    # Index into the correct rewards\n",
    "    c_truncated_reward = chosen_values[i][divergence_ind:end_ind]\n",
    "    r_truncated_reward = reject_values[i][divergence_ind:end_ind]\n",
    "\n",
    "    # Append the last rewards to the list of end scores\n",
    "    chosen_end_scores.append(c_truncated_reward[-1])\n",
    "    rejected_end_scores.append(r_truncated_reward[-1])\n",
    "\n",
    "    # Compute loss\n",
    "    loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()\n",
    "\n",
    "# average the accumulated pairwise loss over the batch once, after the loop\n",
    "loss = loss / bs\n",
    "chosen_end_scores = torch.stack(chosen_end_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52095bab",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b02e60d",
   "metadata": {},
   "source": [
    "# Reward Model Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70abf467",
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f8b4d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 512\n",
    "tokenizer.padding_side = \"right\"\n",
    "# tokenizer.padding_side = \"left\"\n",
    "\n",
    "prompt = \"现代文:行三十五里,进入登封县境的耿店。\"\n",
    "prefix = \"古文:\"\n",
    "pred = \"<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>\"\n",
    "\n",
    "encodings_dict = tokenizer(prompt, prefix+pred, max_length=max_length,\n",
    "                           truncation=\"longest_first\", padding=\"max_length\", return_tensors=\"pt\",\n",
    "                           return_token_type_ids=False)\n",
    "print(encodings_dict.keys())\n",
    "print(encodings_dict['input_ids'].shape)\n",
    "print(encodings_dict['attention_mask'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d88344b",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = encodings_dict['input_ids'].to(device)\n",
    "attention_mask = encodings_dict['attention_mask'].to(device)\n",
    "res = reward_model(input_ids, attention_mask)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: notebook/rlhf.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0fa8199",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\n",
    "sys.path.insert(0, \"/Users/zeyesun/Documents/Code/RLHF\")\n",
    "sys.path.insert(0, \"D:\\\\Code\\\\RLHF\")\n",
    "sys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\n",
    "\n",
    "import os, time, re, random, glob, json, jieba, copy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForMultipleChoice,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    TextGenerationPipeline\n",
    ")\n",
    "\n",
    "from src.models.reward import RewardModel\n",
    "\n",
    "device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "from sys import platform\n",
    "if platform == \"linux\" or platform == \"linux2\":\n",
    "    # linux\n",
    "    root = \"/mnt/sfevol775196/sunzeye273/Data\"\n",
    "#     root = \"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\"\n",
    "#     root = \"/mnt/pa002-28359-vol543625-private/Data\"\n",
    "#     root = \"/root/autodl-tmp/Data\"\n",
    "elif platform == \"darwin\":\n",
    "    # OS X\n",
    "    root = \"/Users/zeyesun/Documents/Data\"\n",
    "elif platform == \"win32\":\n",
    "    # Windows...\n",
    "    root = \"D:\\\\Data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "354bbf38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_name = \"pangu-small\"\n",
    "# model_name = \"pangu-350M\"\n",
    "model_name = \"glm-small\"\n",
    "# model_name = \"chatglm-6B\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "    f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "    f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2de5b11",
   "metadata": {},
   "outputs": [],
   "source": [
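    "# LoRA settings are needed here; nothing earlier in this notebook defines them\n",
    "lora_rank = 0\n",
    "lora_alpha = 1\n",
    "lora_train_bias = \"none\"\n",
    "\n",
    "if \"glm\" in model_name_or_path:\n",
    "    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "    if \"chatglm\" in model_name_or_path:\n",
    "        model = model.half()\n",
    "else:\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)\n",
    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",
    "if lora_rank > 0:\n",
    "    # the LoRA helpers are not imported in the first cell of this notebook\n",
    "    import loralib as lora\n",
    "    from src.models.lora import convert_to_lora_recursively\n",
    "    convert_to_lora_recursively(model, lora_rank, lora_alpha)\n",
    "    lora.mark_only_lora_as_trainable(model, lora_train_bias)\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "print(model.device)"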
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904753d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"你好,你是谁?\"\n",
    "prefix = \"答:\"\n",
    "label = \"我是ChatGPT\"\n",
    "max_length = 32\n",
    "max_gen_length = 16\n",
    "max_prompt_length = max_length - max_gen_length\n",
    "lora_rank = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1484a82",
   "metadata": {},
   "source": [
    "## Dataset Debug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "469acb98",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.data import SFTDataset\n",
    "from torch.utils.data import RandomSampler, DataLoader\n",
    "class dotdict(dict):\n",
    "    \"\"\"dot.notation access to dictionary attributes\"\"\"\n",
    "    __getattr__ = dict.get\n",
    "    __setattr__ = dict.__setitem__\n",
    "    __delattr__ = dict.__delitem__\n",
    "\n",
    "args = {'model_name_or_path': model_name_or_path,\n",
    "        \"max_length\": 128}\n",
    "args = dotdict(args)\n",
    "\n",
    "train_dataset = SFTDataset(args, os.path.join(root, \"chatgpt\", \"processed\", \"test_data_external_v1.jsonl\"), \n",
    "                              tokenizer)\n",
    "train_sampler = RandomSampler(train_dataset)\n",
    "train_dataloader = DataLoader(\n",
    "        train_dataset,\n",
    "        sampler=train_sampler,\n",
    "        batch_size=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14b621de",
   "metadata": {},
   "source": [
    "## Generation Debug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9d87b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.padding_side = \"left\"\n",
    "inputs = tokenizer(prompt, \n",
    "#                    label, \n",
    "#                    tokenizer.sep_token + prefix,\n",
    "                   max_length=max_prompt_length,\n",
    "                   padding=\"max_length\",\n",
    "                   truncation=\"longest_first\", \n",
    "#                    add_special_tokens=False,\n",
    "                   return_tensors=\"pt\", \n",
    "                   return_token_type_ids=False)\n",
    "print(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce652f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size, prompt_length = inputs['input_ids'].shape\n",
    "\n",
    "with torch.no_grad():\n",
    "#     logger.debug(f\"[_generate_sequence] inputs: {inputs}\")\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    seq = model.generate(**inputs, \n",
    "                         max_new_tokens=max_gen_length,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=False,\n",
    "                         num_return_sequences=1,\n",
    "                         top_p=0.9,\n",
    "                         temperature=1.0\n",
    "                        )\n",
    "print(f\"seq: {seq}\")\n",
    "print(tokenizer.batch_decode(seq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "018d2c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = []\n",
    "for i in range(batch_size):\n",
    "    prompt_ids = seq[i, :prompt_length]\n",
    "    prompt_start_index = (prompt_ids != tokenizer.pad_token_id).nonzero()[0].item()\n",
    "    prompt_ids = seq[i, prompt_start_index:prompt_length]\n",
    "    answer_ids = seq[i, prompt_length:]\n",
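    "    # distinct names so the global `prompt` read by later cells is not overwritten\n",
    "    prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=False)\n",
    "    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=False)\n",
    "    prompts.append(prompt_text + answer_text)\n",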
    "print(prompts)\n",
    "outputs = tokenizer(prompts, max_length=max_length,\n",
    "                              truncation=\"longest_first\", padding=\"max_length\",\n",
    "                              return_tensors=\"pt\", return_token_type_ids=False)\n",
    "print(outputs)\n",
    "print(tokenizer.batch_decode(outputs['input_ids']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429917ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs['input_ids'].device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22ca3132",
   "metadata": {},
   "source": [
    "## ChatGLM attention mask and position ids Debug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3308b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# chatglm build attention mask\n",
    "input_ids = inputs['input_ids']\n",
    "batch_size, seq_length = input_ids.shape\n",
    "context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\n",
    "attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\n",
    "print(attention_mask.shape)\n",
    "attention_mask.tril_()\n",
    "for i, context_length in enumerate(context_lengths):\n",
    "    attention_mask[i, :, :context_length] = 1\n",
    "print(attention_mask.shape)\n",
    "attention_mask.unsqueeze_(1)\n",
    "print(attention_mask.shape)\n",
    "# attention_mask = (attention_mask < 0.5).bool()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c4fd4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# chatglm build position ids\n",
    "batch_size, seq_length = input_ids.shape\n",
    "context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\n",
    "# if self.position_encoding_2d:\n",
    "position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n",
    "# if not gmask:\n",
    "#     for i, context_length in enumerate(context_lengths):\n",
    "#         position_ids[i, context_length:] = mask_positions[i]\n",
    "block_position_ids = [torch.cat((\n",
    "    torch.zeros(context_length, dtype=torch.long, device=device),\n",
    "    torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\n",
    ")) for context_length in context_lengths]\n",
    "block_position_ids = torch.stack(block_position_ids, dim=0)\n",
    "position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n",
    "# else:\n",
    "#     position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n",
    "#     if not gmask:\n",
    "#         for i, context_length in enumerate(context_lengths):\n",
    "#             position_ids[context_length:] = mask_positions[i]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b29e22",
   "metadata": {},
   "source": [
    "## GLM attention mask and position ids Debug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b12dda63",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n",
    "prompt_length = len(encoded_prompt['input_ids'])\n",
    "label_length = len(tokenizer.tokenize(label)) + 1\n",
    "if prompt_length + label_length > max_length:\n",
    "    num_tokens_to_remove = prompt_length + label_length - max_length\n",
    "    for _ in range(num_tokens_to_remove):\n",
    "        if prompt_length > label_length:\n",
    "            prompt_length -= 1\n",
    "        else:\n",
    "            label_length -= 1\n",
    "else:\n",
    "    label_length = max_length - prompt_length\n",
    "assert prompt_length > 0\n",
    "assert label_length > 0\n",
    "assert prompt_length + label_length <= max_length\n",
    "inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\n",
    "                              max_length=prompt_length,\n",
    "                              truncation=\"only_first\",\n",
    "                              return_tensors=\"pt\",\n",
    "                              return_attention_mask=True,\n",
    "                              return_token_type_ids=False)\n",
    "inputs_glm = tokenizer.build_inputs_for_generation(inputs, targets=label,\n",
    "                                                          max_gen_length=label_length, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6521cafd",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.padding_side = \"left\"\n",
    "encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n",
    "prompt_length = len(encoded_prompt['input_ids'])\n",
    "inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\n",
    "                              max_length=max_prompt_length,\n",
    "                              padding=\"max_length\",\n",
    "                              truncation=\"only_first\",\n",
    "                              return_tensors=\"pt\",\n",
    "                              return_token_type_ids=False)\n",
    "# max_gen_length = max_length - prompt_length\n",
    "# max_gen_length = self.args.max_gen_length\n",
    "assert prompt_length > 0\n",
    "assert max_gen_length > 0\n",
    "assert prompt_length + max_gen_length <= max_length\n",
    "inputs_glm = tokenizer.build_inputs_for_generation(inputs, max_gen_length=max_gen_length, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6344ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs)\n",
    "print(inputs['input_ids'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb5447b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs_glm)\n",
    "print(inputs_glm['input_ids'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7edfa03",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: notebook/sft.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f526208",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\n",
    "sys.path.insert(0, \"/Users/zeyesun/Documents/Code/RLHF\")\n",
    "sys.path.insert(0, \"D:\\\\Code\\\\RLHF\")\n",
    "sys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\n",
    "sys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\n",
    "\n",
    "import os, time, re, random, glob, json, jieba, copy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import loralib as lora\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForMultipleChoice,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    TextGenerationPipeline\n",
    ")\n",
    "\n",
    "from src.models.lora import convert_to_lora_recursively\n",
    "\n",
    "device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "from sys import platform\n",
    "if platform == \"linux\" or platform == \"linux2\":\n",
    "    # linux\n",
    "    root = \"/mnt/sfevol775196/sunzeye273/Data\"\n",
    "#     root = \"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\"\n",
    "#     root = \"/mnt/pa002-28359-vol543625-private/Data\"\n",
    "#     root = \"/root/autodl-tmp/Data\"\n",
    "elif platform == \"darwin\":\n",
    "    # OS X\n",
    "    root = \"/Users/zeyesun/Documents/Data\"\n",
    "elif platform == \"win32\":\n",
    "    # Windows...\n",
    "    root = \"D:\\\\Data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633e69a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_name = \"pangu-350M\"\n",
    "# model_name = \"glm-350M-chinese\"\n",
    "model_name = \"chatglm-6B\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "# model_name_or_path = os.path.join(\"/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "#     f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "#     f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c62c599",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_rank = 0\n",
    "lora_alpha = 1\n",
    "lora_train_bias = \"none\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cae21a",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"glm\" in model_name_or_path:\n",
    "    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "    if \"chatglm\" in model_name_or_path:\n",
    "        model = model.half()\n",
    "else:\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)\n",
    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",
    "if lora_rank > 0:\n",
    "    convert_to_lora_recursively(model, lora_rank, lora_alpha)\n",
    "    lora.mark_only_lora_as_trainable(model, lora_train_bias)\n",
    "model.to(device)\n",
    "print(model.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a28f07a",
   "metadata": {},
   "source": [
    "# SFT Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b003a1c8",
   "metadata": {},
   "outputs": [],
   "source": [
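    "checkpoint_files = os.path.join(root, \"chatgpt\", \"output\", \"sft\", model_name, \"pytorch_model*.bin\")\n",
    "# checkpoint_files = os.path.join(root, \"chatgpt\", \"output\", \"sft\", \"pangu-350M\", \"checkpoint-57043\", \"pytorch_model*.bin\")\n",
    "# checkpoint_files = os.path.join(root, \"chatgpt\", \"output\", \"sft\", \"pangu-2.6B\", \"pangu-2.6B-sft\", \"pytorch_model*.bin\")\n",
    "checkpoints = glob.glob(checkpoint_files)\n",
    "# fail clearly instead of passing an empty state dict to load_state_dict\n",
    "assert checkpoints, f\"no checkpoint files match {checkpoint_files}\"\n",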
    "st = dict()\n",
    "for checkpoint in checkpoints:\n",
    "    st.update(torch.load(checkpoint, map_location=\"cpu\"))\n",
    "model.load_state_dict(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9330b5aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 512\n",
    "max_length_generation = 50\n",
    "num_return_sequences = 1\n",
    "top_p = 0.8\n",
    "top_k = 50\n",
    "temperature = 5.0\n",
    "# prompt = '今天晚上我在睡觉.........他想要做那些事..我就大大声骂他\"不要吵我睡觉\"!!!!!...他就跑出去了...还不接我电话'\n",
    "# prompt = \"\"\"阅读文章:《战国无双3》()是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容,村雨城模式剔除,战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品\\n问:《战国无双3》是由哪两个公司合作开发的?\"\"\"\n",
    "# prefix = \"答:\"\n",
    "\n",
    "prompt = \"以王维的风格,写一首五言诗,题为和宋中丞夏日游福贤观天长寺寺即陈左相宅所施之作\"\n",
    "prefix = \"\"\n",
    "\n",
    "# prompt = \"上联:东风执笔点龙睛,看幸福指数,天天向上\"\n",
    "# prefix = \"下联:\"\n",
    "\n",
    "while True:\n",
    "    if \"chatglm\" in model_name_or_path:\n",
    "        encoded_prompt = tokenizer(prompt)\n",
    "        prompt_length = len(encoded_prompt['input_ids'])\n",
    "        inputs = tokenizer(prompt, max_length=min(prompt_length, max_length), truncation=\"only_first\",\n",
    "                   return_tensors=\"pt\")\n",
    "        inputs = inputs.to(device)\n",
    "        outputs = model.generate(**inputs,\n",
    "                                 max_new_tokens=max_length_generation,\n",
    "                                 eos_token_id=tokenizer.eop_token_id,\n",
    "                                 pad_token_id=tokenizer.pad_token_id,\n",
    "                                 do_sample=False,\n",
    "                                 num_return_sequences=num_return_sequences,\n",
    "                                 top_p=top_p,\n",
    "                                 temperature=temperature)\n",
    "    elif \"glm\" in model_name_or_path:\n",
    "        encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n",
    "        prompt_length = len(encoded_prompt['input_ids'])\n",
    "        inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\n",
    "#                            max_length=max_length - max_length_generation,\n",
    "#                            padding=\"max_length\",\n",
    "                           max_length=min(prompt_length, max_length),\n",
    "                           truncation=\"only_first\",\n",
    "                           return_tensors=\"pt\",\n",
    "                           return_attention_mask=True,\n",
    "                           return_token_type_ids=False)\n",
    "        max_length_generation = max_length - inputs['input_ids'].shape[1]\n",
    "        inputs_glm = tokenizer.build_inputs_for_generation(inputs, \n",
    "                                                           max_gen_length=max_length_generation, padding=True)\n",
    "        inputs_glm = inputs_glm.to(device)\n",
    "        outputs = model.generate(**inputs_glm,\n",
    "                                 max_new_tokens=max_length_generation,\n",
    "                                 eos_token_id=tokenizer.eop_token_id,\n",
    "                                 pad_token_id=tokenizer.pad_token_id,\n",
    "                                 do_sample=False,\n",
    "                                 num_return_sequences=num_return_sequences,\n",
    "                                 top_p=top_p,\n",
    "                                 temperature=temperature)\n",
    "    else:\n",
    "        inputs = tokenizer(prompt, tokenizer.sep_token + prefix, \n",
    "                           max_length=max_length, \n",
    "                           return_tensors=\"pt\",\n",
    "                           truncation=\"only_first\", \n",
    "#                            padding=\"max_length\",\n",
    "                           add_special_tokens=False,\n",
    "                           return_token_type_ids=False)\n",
    "        inputs = inputs.to(device)\n",
    "        outputs = model.generate(**inputs,\n",
    "                                 max_new_tokens=max_length_generation,\n",
    "                                 pad_token_id=tokenizer.pad_token_id,\n",
    "                                 do_sample=True,\n",
    "                                 num_return_sequences=num_return_sequences,\n",
    "#                                  top_p=top_p,\n",
    "                                 top_k=top_k,\n",
    "                                 temperature=temperature)\n",
    "    results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "#     results = [result.split(prefix, maxsplit=1)[1] for result in results]\n",
    "    print(results)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbb3482",
   "metadata": {},
   "outputs": [],
   "source": [
    "# f = os.path.join(root, \"raw\", \"baike_qa_train.json\")\n",
    "# [baike_qa.jsonl, chinese_classical.jsonl, chinese_poetry.jsonl, couplets.jsonl, weibo_summary_comments.jsonl, zhidao.jsonl]\n",
    "f = os.path.join(root, \"chatgpt\", \"processed\", \"baike_qa.jsonl\")\n",
    "i = 0\n",
    "prompts = []\n",
    "prompts_processed = []\n",
    "labels = []\n",
    "with open(f, \"r\", encoding=\"utf-8\") as r:\n",
    "    while True:\n",
    "        line = r.readline()\n",
    "        if not line:\n",
    "            break\n",
    "        item = json.loads(line.strip(\"\\n\"))\n",
    "        # prompt = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])\n",
    "        # prompt_processed = prompt + tokenizer.sep_token + prefix\n",
    "        # label = clean_text(item['answer'])\n",
    "        prompt = item['prompt']\n",
    "        prompt_processed = prompt\n",
    "        label = item['answers'][0]['answer']\n",
    "        prompts.append(prompt)\n",
    "        prompts_processed.append(prompt_processed)\n",
    "        labels.append(label)\n",
    "        i += 1\n",
    "        # if i > 1000:\n",
    "        #     break\n",
    "random.shuffle(prompts_processed)\n",
    "print(len(prompts_processed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad64148",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 79\n",
    "num_return_sequences = 2\n",
    "max_length = 512\n",
    "max_length_generation = 100\n",
    "top_k = 50\n",
    "top_p = 0.8\n",
    "temperature = 1.0\n",
    "t1 = time.time()\n",
    "prompt = prompts_processed[i]\n",
    "inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "outputs = model.generate(**inputs,\n",
    "                         max_new_tokens=max_length_generation,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=True,\n",
    "                         num_return_sequences=num_return_sequences,\n",
    "                         # top_p=top_p,\n",
    "                         top_k=top_k,\n",
    "                         temperature=temperature)\n",
    "results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "for result in results:\n",
    "    # result.split(prefix, maxsplit=1)[1]\n",
    "    print(result)\n",
    "\n",
    "\n",
    "# results = text_generator(prompts_processed[i:j], max_length=200, num_return_sequences=num_return_sequences,\n",
    "#                          do_sample=True, top_k=50, temperature=10.0)\n",
    "# print(f\"Finished prediction, time taken: {time.time()-t1}\")\n",
    "\n",
    "# for prompt, res, label in zip(prompts[i:j], results[:(j-i)], labels[i:j]):\n",
    "#     print(f\"prompt: {prompt}\\nlabel: {label}\")\n",
    "#     for k in range(num_return_sequences):\n",
    "#         model_answer = res[k]['generated_text'].split(prefix)[1].replace(\"<eot>\", \"\").replace(\"<pad>\", \"\")\n",
    "#         print(f\"model answer-{k}: {model_answer}\")\n",
    "#     print(\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da1e08ce",
   "metadata": {},
   "source": [
    "# SFT Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eef5821",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompt = \"你是谁\"\n",
    "# prefix = \"答:\"\n",
    "# label = \"我是***,很高兴为你服务\"\n",
    "prompt = \"\"\"倍数金额: 1倍,¥1024 元 场次 主---------客队 投注选项- 参考赔率---投注选项胜 平 负---\n",
    "第一次为基本面投注---第二次为通过处理后投注  1 伯明翰 VS -----维冈 31-----1.93 3.27 3.87 ---  2伯恩利VS---朴茨茅30----- ---3 博尔顿 VS -----狼队\n",
    "3------1.94 3.25 3.88 ---  4 斯托克 VS ---阿森纳 0------5.03 3.47 1.68 ---  5 门兴 VS -----弗赖堡 31-----1.77 3.倍数金额: 1倍,¥1024 元 场次\n",
    "主---------客队 投注选项- 参考赔率---投注选项胜 平 负--- 第一次为基本面投注---第二次为通过处理后投注  1 伯明翰 VS -----维冈 31-----1.93 3.27 3.87 ---\n",
    "2伯恩利VS---朴茨茅30----- ---3 博尔顿 VS -----狼队 3------1.94 3.25 3.88 ---  4 斯托克 VS ---阿森纳 0------5.03 3.47 1.68 ---  5 门兴 VS -----弗赖堡\n",
    "31-----1.77 3.39 4.43 ---  6 美因兹 VS ---不来梅 10-----3.76 3.34 1.92 ---  7波鸿VS-----纽伦堡30----- ---8 斯图加 VS ---法兰克 31-----1.59 3.62 5.47\n",
    "---  9 赫塔 VS -----霍芬海 30-----2.49 3.19 2.69 ---  10 勒沃 VS ------科隆 3------1.35 4.44 8.31 ---  11卡塔尼VS----巴里31----- ---12 拉齐奥 VS\n",
    "--佛罗伦 31-----2.35 3.05 3.01 ---  13 特内里 VS ----皇马 0------9.43 4.95 1.29 ---  14 巴萨 VS ----马拉加 3------1.15 6.78 15.49 --\"\"\"\n",
    "prefix = \"回答:\"\n",
    "label = \"你出的赔率数据太早了,数据随时都会变化,这就是所谓要看临盘的道理,目前的数据没什么参考价值。\"\n",
    "max_length = 512\n",
    "encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n",
    "prompt_length = len(encoded_prompt['input_ids'])\n",
    "label_length = len(tokenizer.tokenize(label)) + (1 if \"chatglm\" not in model_name_or_path else 0)\n",
    "# print(f\"prompt length: {prompt_length}, label length: {label_length}\")\n",
    "if prompt_length + label_length > max_length:\n",
    "    num_tokens_to_remove = prompt_length + label_length - max_length\n",
    "    for _ in range(num_tokens_to_remove):\n",
    "        if prompt_length > label_length:\n",
    "            prompt_length -= 1\n",
    "        else:\n",
    "            label_length -= 1\n",
    "else:\n",
    "    label_length = max_length - prompt_length\n",
    "assert prompt_length > 0\n",
    "assert label_length > 0\n",
    "assert prompt_length + label_length <= max_length\n",
    "encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n",
    "                         max_length=prompt_length, truncation=\"only_first\",\n",
    "                         return_tensors=\"pt\", return_attention_mask=True)\n",
    "inputs = tokenizer.build_inputs_for_generation(encoded_dict, targets=label,\n",
    "                                               max_gen_length=label_length, padding=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8b9e45",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "#     f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "#     f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f99c82e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tokenizer.convert_ids_to_tokens([20006]))\n",
    "print(tokenizer.convert_ids_to_tokens([20012]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d349a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(encoded_prompt)\n",
    "print(tokenizer.decode(encoded_prompt['input_ids']))\n",
    "print(encoded_dict)\n",
    "print(tokenizer.batch_decode(encoded_dict['input_ids']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a63f417",
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, val in inputs.items():\n",
    "    print(f\"{key} shape: {val.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f78df96",
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, val in inputs_glm.items():\n",
    "    print(f\"{key} shape: {val.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4261c76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs_glm['input_ids'][:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6667eaac",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs_glm['labels'][:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac236f19",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs_glm['attention_mask'][0][9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda9e8ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(inputs_glm['position_ids'][0][:20])\n",
    "print(inputs_glm['position_ids'][1][:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83126e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# st = model.state_dict()\n",
    "st.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7320e1ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(st['transformer.word_embeddings.weight'].dtype)\n",
    "print(st['transformer.layers.0.input_layernorm.weight'].dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81289f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dtypes = dict()\n",
    "for key, val in st.items():\n",
    "    if val.dtype not in dtypes:\n",
    "        dtypes[val.dtype] = list()\n",
    "    dtypes[val.dtype].append(key)\n",
    "print(dtypes.keys())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f4e3327",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = model(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcc223bf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: notebook/test.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c83ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, time, re, random, glob, json, jieba, copy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForMultipleChoice,\n",
    "    AutoTokenizer,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator\n",
    ")\n",
    "device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "from sys import platform\n",
    "if platform == \"linux\" or platform == \"linux2\":\n",
    "    # linux\n",
    "    root = \"/mnt/sfevol775196/sunzeye273/Data\"\n",
    "#     root = \"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\"\n",
    "#     root = \"/mnt/pa002-28359-vol543625-private/Data\"\n",
    "#     root = \"/root/autodl-tmp/Data\"\n",
    "elif platform == \"darwin\":\n",
    "    # OS X\n",
    "    root = \"/Users/zeyesun/Documents/Data\"\n",
    "elif platform == \"win32\":\n",
    "    # Windows...\n",
    "    root = \"D:\\\\Data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6fdca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"llama-7b\"\n",
    "# model_name = \"glm-350M-chinese\"\n",
    "# model_name = \"chatglm-6B\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "print(tokenizer.special_tokens_map)\n",
    "print(tokenizer.all_special_ids)\n",
    "print(\n",
    "    f\"unk: {tokenizer.unk_token_id}\\n\",\n",
    "    f\"pad: {tokenizer.pad_token_id}\\n\",\n",
    "    f\"bos: {tokenizer.bos_token_id}\\n\",\n",
    "    f\"eos: {tokenizer.eos_token_id}\\n\",\n",
    "    f\"sep: {tokenizer.sep_token_id}\\n\",\n",
    "    f\"mask: {tokenizer.mask_token_id}\\n\",\n",
    "#     f\"eop: {tokenizer.eop_token_id}\\n\"\n",
    "#     f\"sop: {tokenizer.sop_token_id}\\n\"\n",
    "    f\"cls: {tokenizer.cls_token_id}\"\n",
    ") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "52c6acb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2916.6666666666665"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "35000/ (3*4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df9a2572",
   "metadata": {},
   "source": [
    "# LLaMA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ccd5025",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"llama-small\"\n",
    "model_name_or_path = os.path.join(root, \"models\", model_name)\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\n",
    "config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d76f86cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = os.path.join(model_name_or_path, \"pytorch_model.bin\")\n",
    "torch.save(model.state_dict(), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5405693b",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"你好 </s> 你是谁 </s> 你好 </s> 你是谁\"\n",
    "eos_ids = [0, 5, 12, 16, 22]\n",
    "max_length = 512\n",
    "input_embeds = torch.rand(1, dtype=torch.float16)\n",
    "# tokenizer.tokenize(prompt, add_special_tokens=False)\n",
    "inputs = tokenizer(prompt,  max_length=max_length, return_tensors=\"pt\", truncation=\"longest_first\", \n",
    "          return_attention_mask=True, return_token_type_ids=False)\n",
    "# print(inputs)\n",
    "tgt_len = len(inputs['input_ids'][0])\n",
    "print(f\"tgt_len: {tgt_len}\")\n",
    "combined_attention_mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(input_embeds.dtype).min))\n",
    "# print(f\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\")\n",
    "for i in range(len(eos_ids)-1):\n",
    "    attention_mask = torch.ones((1, eos_ids[i+1]-eos_ids[i]), dtype=torch.long)\n",
    "#     print(f\"{i}, shape: {attention_mask.shape}, attention_mask: {attention_mask}\")\n",
    "    attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape, \n",
    "                                    input_embeds, 0)\n",
    "    combined_attention_mask[eos_ids[i]:eos_ids[i+1], eos_ids[i]:eos_ids[i+1]] = attention_mask\n",
    "#     print(f\"{i}-th, shape: {attention_mask.shape}, attention_mask: {attention_mask}\")\n",
    "#     break\n",
    "# print(f\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "317fe60a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_causal_mask(\n",
    "    input_ids_shape, dtype, device, past_key_values_length = 0\n",
    "):\n",
    "    \"\"\"\n",
    "    Make causal mask used for bi-directional self-attention.\n",
    "    \"\"\"\n",
    "    bsz, tgt_len = input_ids_shape\n",
    "    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\n",
    "    mask_cond = torch.arange(mask.size(-1), device=device)\n",
    "    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n",
    "    mask = mask.to(dtype)\n",
    "\n",
    "    if past_key_values_length > 0:\n",
    "        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n",
    "    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n",
    "\n",
    "\n",
    "# Copied from transformers.models.bart.modeling_bart._expand_mask\n",
    "def _expand_mask(mask, dtype, tgt_len):\n",
    "    \"\"\"\n",
    "    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n",
    "    \"\"\"\n",
    "    bsz, src_len = mask.size()\n",
    "    tgt_len = tgt_len if tgt_len is not None else src_len\n",
    "\n",
    "    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n",
    "\n",
    "    inverted_mask = 1.0 - expanded_mask\n",
    "\n",
    "    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n",
    "\n",
    "def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):\n",
    "        # create causal mask\n",
    "        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n",
    "        combined_attention_mask = None\n",
    "        if input_shape[-1] > 1:\n",
    "            combined_attention_mask = _make_causal_mask(\n",
    "                input_shape,\n",
    "                inputs_embeds.dtype,\n",
    "                device=inputs_embeds.device,\n",
    "                past_key_values_length=past_key_values_length,\n",
    "            )\n",
    "\n",
    "        if attention_mask is not None:\n",
    "            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n",
    "            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n",
    "                inputs_embeds.device\n",
    "            )\n",
    "            combined_attention_mask = (\n",
    "                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n",
    "            )\n",
    "\n",
    "        return combined_attention_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f589231f",
   "metadata": {},
   "source": [
    "# GLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7268832",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"你好\"\n",
    "prefix = \"答:\"\n",
    "max_length = 128\n",
    "encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n",
    "prompt_length = len(encoded_prompt['input_ids'])\n",
    "encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n",
    "                         max_length=min(prompt_length, max_length),\n",
    "                         truncation=\"only_first\",\n",
    "                         return_tensors=\"pt\",\n",
    "                         return_token_type_ids=False)\n",
    "max_gen_length = max_length - encoded_dict['input_ids'].shape[1]\n",
    "inputs = tokenizer.build_inputs_for_generation(encoded_dict, max_gen_length=max_gen_length, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1fc75f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd9e66a",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size, seq_length = encoded_dict['input_ids'].shape[:2]\n",
    "attention_mask1 = encoded_dict['attention_mask']\n",
    "attention_mask2 = attention_mask1.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)\n",
    "generation_attention_mask = torch.cat([attention_mask2.new_zeros((seq_length, max_gen_length)),\n",
    "                                       torch.tril(attention_mask2.new_ones((max_gen_length, max_gen_length)))],\n",
    "                                      dim=0).unsqueeze(0).expand(batch_size, -1, -1)\n",
    "attention_mask3 = torch.cat((attention_mask2, generation_attention_mask), dim=2)\n",
    "attention_mask4 = attention_mask3.unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b629712",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(attention_mask1.shape)\n",
    "print(attention_mask1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd9341e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(attention_mask2.shape)\n",
    "# print(attention_mask2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9f75428",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(attention_mask3.shape)\n",
    "print(attention_mask3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b600887",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(attention_mask4.shape)\n",
    "print(attention_mask4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e028572",
   "metadata": {},
   "source": [
    "# ChatGLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c25ba334",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = os.path.join(root, \"models\", \"chatglm-6B\")\n",
    "# model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
    "# model = model.half().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc68c1cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\n",
    "    \"你好\",\n",
    "    \"晚上睡不着应该怎么办\"\n",
    "]\n",
    "history = []\n",
    "for text in texts:\n",
    "    response, history = model.chat(tokenizer, text, history=history)\n",
    "    print(f\"问: {text}\\n答:{response}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4ccfa2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 512\n",
    "prompt = \"晚上睡不着应该怎么办\"\n",
    "label = \"吃褪黑素\"\n",
    "encoded_dict = tokenizer(prompt, label,\n",
    "                         max_length=max_length, return_tensors=\"pt\",\n",
    "                         truncation=\"longest_first\", padding=\"max_length\", return_token_type_ids=False)\n",
    "print(encoded_dict.keys())\n",
    "print(encoded_dict['input_ids'].shape)\n",
    "print(encoded_dict['input_ids'][0, :50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50baebfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = [ 94747, 103400,  84213,  86846, 150001, 150004,  20005,  84020, 110857,\n",
    "         84248,  84886,  20003]\n",
    "print([tokenizer._convert_id_to_token(i) for i in ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9104b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = encoded_dict['input_ids']\n",
    "print(input_ids[0, :20])\n",
    "seqs = input_ids.tolist()\n",
    "mask_positions = [seq.index(tokenizer.gmask_token_id) for seq in seqs]\n",
    "gmask = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7b577cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size, seq_length = input_ids.shape\n",
    "context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\n",
    "attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\n",
    "attention_mask.tril_()\n",
    "for i, context_length in enumerate(context_lengths):\n",
    "    attention_mask[i, :, :context_length] = 1\n",
    "attention_mask.unsqueeze_(1)\n",
    "# attention_mask = (attention_mask < 0.5).bool()\n",
    "print(attention_mask.shape)\n",
    "print(attention_mask[0, :20, :20])\n",
    "batch_size, seq_length = input_ids.shape\n",
    "context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\n",
    "# if self.position_encoding_2d:\n",
    "position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n",
    "if not gmask:\n",
    "    for i, context_length in enumerate(context_lengths):\n",
    "        position_ids[i, context_length:] = mask_positions[i]\n",
    "block_position_ids = [torch.cat((\n",
    "    torch.zeros(context_length, dtype=torch.long, device=device),\n",
    "    torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\n",
    ")) for context_length in context_lengths]\n",
    "block_position_ids = torch.stack(block_position_ids, dim=0)\n",
    "position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n",
    "# else:\n",
    "#     position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n",
    "#     if not gmask:\n",
    "#         for i, context_length in enumerate(context_lengths):\n",
    "#             position_ids[context_length:] = mask_positions[i]\n",
    "print(position_ids.shape)\n",
    "print(position_ids[0, :, :20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d91e23d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tokenizer.tokenize(prompt))\n",
    "print(tokenizer(prompt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d10f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_prompt = tokenizer(prompt)\n",
    "prompt_length = len(encoded_prompt['input_ids'])\n",
    "inputs = tokenizer(prompt, max_length=min(prompt_length, max_length), truncation=\"only_first\",\n",
    "                   return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6cd0dfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length_generation = 50\n",
    "num_return_sequences = 1\n",
    "top_p = 0.8\n",
    "temperature = 1.0\n",
    "outputs = model.generate(**inputs, max_new_tokens=max_length_generation,\n",
    "                         eos_token_id=tokenizer.eop_token_id,\n",
    "                         pad_token_id=tokenizer.pad_token_id,\n",
    "                         do_sample=False,\n",
    "                         num_return_sequences=num_return_sequences,\n",
    "                         top_p=top_p,\n",
    "                         temperature=temperature)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f13279f",
   "metadata": {},
   "source": [
    "# DeepSpeed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7798b57b",
   "metadata": {},
   "source": [
    "### Estimate the memory needed for parameters, optimizer states, and gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07493af2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stage 1 and 2\n",
    "from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live\n",
    "estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=3, num_nodes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cb3b908",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stage 3\n",
    "from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live\n",
    "estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=3, num_nodes=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9aa232d",
   "metadata": {},
   "source": [
    "# Split a PyTorch checkpoint into multiple shard files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "705292a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = os.path.join(root, \"chatgpt\", \"output\", \"sft\", \"pangu-2.6B\", \"checkpoint-42782\")\n",
    "if \"glm\" in model_name_or_path:\n",
    "    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)\n",
    "else:\n",
    "    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, use_cache=False)\n",
    "st = model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d5ee32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "print(sys.getsizeof(st))\n",
    "keys = list(st.keys())\n",
    "n = 10\n",
    "m = {\"metadata\": {\"total_size\":sys.getsizeof(st)}, \"weight_map\":dict()}\n",
    "span = len(keys) // n\n",
    "for i in range(n):\n",
    "    fn = f\"pytorch_model-{i+1}-of-{n}.bin\"\n",
    "    f = os.path.join(checkpoint, fn)\n",
    "    stt = dict()\n",
    "    for key in keys[i*span:(i+1)*span]:\n",
    "        stt[key] = st[key]\n",
    "        m[\"weight_map\"][key] = fn\n",
    "    torch.save(stt, f)\n",
    "f = os.path.join(checkpoint, \"pytorch_model.bin.index.json\")\n",
    "json.dump(m, open(f, \"w\", encoding=\"utf-8\"), ensure_ascii=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: requirements.txt
================================================
jieba
sentencepiece
nltk>=3.8.1
rouge-score>=0.1.2
evaluate>=0.4.0
numpy==1.23.2
tensorflow>=2.5.0
transformers>=4.31.0
tokenizers>=0.13.3
#triton==1.0.0
triton
ninja
torchmetrics
opencc
torchtyping
accelerate>=0.19.0
ray
tqdm
rich
icetk
cpm_kernels
protobuf==3.20.0
pynvml
bitsandbytes>=0.39.0
peft>=0.4.0
transformers_stream_generator
trl
trlx

================================================
FILE: src/__init__.py
================================================


================================================
FILE: src/data/__init__.py
================================================


================================================
FILE: src/data/data.py
================================================
import os
import json
import re
import random
from typing import Tuple, List

import torch
import pandas as pd

from tqdm import tqdm
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from transformers import PreTrainedTokenizerBase

from src.utils import logger, RESOURCE_PATH
from src.utils.modeling_utils import _prepare_decoder_attention_mask, qwen_make_context
from src.utils.file_utils import print_rank_0


def chatglm3_encode(tokenizer: "PreTrainedTokenizerBase",
                    query: str,
                    label: str = None,
                    system: str = "",
                    max_length: int = 1024,
                    is_prefix: bool = True
                    ) -> Tuple[List[int], List[int], List[int]]:
    '''Encode a single-turn ChatGLM3 conversation into fixed-length token id lists.

    Layout: prefix tokens + ``<|user|>`` + newline + [system] + query + ``<|assistant|>``
    + label + ``<eos>``; when ``is_prefix`` is False the system text follows the query.
    If the total exceeds ``max_length``, tokens are dropped one at a time from the end of
    the longer of (query + system) and label ("longest_first"); system tokens are only
    dropped once both query and label are exhausted. Shorter sequences are right-padded
    with ``tokenizer.pad_token_id``.

    :param tokenizer: ChatGLM3 tokenizer; must provide ``get_prefix_tokens``, ``get_command``,
        ``encode`` and ``pad_token_id``
    :param query: user query text
    :param label: target answer text; ``None`` for prompt-only encoding (no ``<eos>`` appended)
    :param system: optional system text, placed before (``is_prefix``) or after the query
    :param max_length: exact length of the returned ``input_ids`` and ``labels``
    :param is_prefix: whether ``system`` precedes the query (True) or follows it (False)
    :return: ``(input_ids, labels, prompt_ids)``; ``labels`` holds ``pad_token_id`` at every
        prompt position followed by the (padded) label ids
    :raises ValueError: if ``max_length`` is smaller than the fixed special/role tokens
    '''
    prefix_tokens = tokenizer.get_prefix_tokens()
    role_tokens_1 = [tokenizer.get_command("<|user|>")] + tokenizer.encode("\n", add_special_tokens=False)
    # Process `system` and `query`. Where a leading space is prepended, the first resulting
    # token is dropped (presumably to normalise word-start tokenisation — TODO confirm).
    if is_prefix:
        system_ids = tokenizer.encode(system + "\n\n", add_special_tokens=False) if len(system) > 0 else []
        query_ids = tokenizer.encode(" " + query, add_special_tokens=False)[1:]
    else:
        system_ids = tokenizer.encode(" \n\n" + system, add_special_tokens=False)[1:] if len(system) > 0 else []
        query_ids = tokenizer.encode(query, add_special_tokens=False)
    # Process `label`
    role_tokens_2 = [tokenizer.get_command("<|assistant|>")]
    if label is not None:
        label_ids = tokenizer.encode(label, add_special_tokens=False)
        end_tokens = [tokenizer.get_command("<eos>")]
    else:
        label_ids = []
        end_tokens = []

    # Tokens that can never be truncated; if they alone overflow, truncation would end up
    # popping from an empty list, so fail early with a clear message instead.
    fixed_length = len(prefix_tokens) + len(role_tokens_1) + len(role_tokens_2) + len(end_tokens)
    if fixed_length > max_length:
        raise ValueError(f"max_length={max_length} is too small to hold the {fixed_length} "
                         f"special tokens required by the ChatGLM3 template")
    num_tokens_to_remove = fixed_length + len(query_ids) + len(system_ids) + len(label_ids) - max_length

    # Remove overflowing tokens ("longest_first"); system tokens go last.
    warned = False
    for _ in range(max(num_tokens_to_remove, 0)):
        if len(query_ids) + len(system_ids) > len(label_ids) and len(query_ids) > 0:
            query_ids.pop()
        elif len(label_ids) > 0:
            label_ids.pop()
        else:
            if not warned:
                logger.warning("removing system tokens due to tokens overflowing")
                warned = True
            system_ids.pop()
    label_ids += end_tokens
    if num_tokens_to_remove < 0:
        # Right-pad to exactly max_length
        label_ids += [tokenizer.pad_token_id] * -num_tokens_to_remove

    if is_prefix:
        prompt_ids = prefix_tokens + role_tokens_1 + system_ids + query_ids + role_tokens_2
    else:
        prompt_ids = prefix_tokens + role_tokens_1 + query_ids + system_ids + role_tokens_2
    input_ids = prompt_ids + label_ids
    labels = [tokenizer.pad_token_id] * len(prompt_ids) + label_ids
    assert len(input_ids) == len(labels) == max_length
    return input_ids, labels, prompt_ids


def chatglm2_encode(tokenizer: "PreTrainedTokenizerBase",
                    query: str,
                    label: str = None,
                    system: str = "",
                    max_length: int = 1024,
                    is_prefix: bool = True
                    ) -> Tuple[List[int], List[int], List[int]]:
    '''Encode a single-turn ChatGLM2 conversation into fixed-length token id lists.

    Layout: ``[gMASK] sop`` + round header + [system] + query + answer marker + label + ``eop``;
    when ``is_prefix`` is False the system text follows the query. If the total exceeds
    ``max_length``, tokens are dropped one at a time from the end of the longer of
    (query + system) and label ("longest_first"); system tokens are only dropped once both
    query and label are exhausted. Shorter sequences are right-padded with
    ``tokenizer.pad_token_id``.

    :param tokenizer: ChatGLM2 tokenizer; must provide ``get_command``, ``encode`` and ``pad_token_id``
    :param query: user query text
    :param label: target answer text; ``None`` for prompt-only encoding (no ``eop`` appended)
    :param system: optional system text, placed before (``is_prefix``) or after the query
    :param max_length: exact length of the returned ``input_ids`` and ``labels``
    :param is_prefix: whether ``system`` precedes the query (True) or follows it (False)
    :return: ``(input_ids, labels, prompt_ids)``; ``labels`` holds ``pad_token_id`` at every
        prompt position followed by the (padded) label ids
    :raises ValueError: if ``max_length`` is smaller than the fixed special/template tokens
    '''
    gmask_id = tokenizer.get_command("[gMASK]")
    sop_id = tokenizer.get_command("sop")
    eop_id = tokenizer.get_command("eop")
    # Pre-computed ids of the template "[Round 1]" + blank line + "问:" ("问" = question);
    # specific to the ChatGLM2 vocabulary.
    ids1 = [790, 30951, 517, 30910, 30939, 30996, 13, 13, 54761, 31211]
    # Pre-computed ids of blank line + "答:" ("答" = answer).
    ids2 = [13, 13, 55437, 31211]
    # A leading space is prepended and the first resulting token dropped
    # (presumably to normalise word-start tokenisation — TODO confirm).
    if len(system) > 0:
        if is_prefix:
            system_ids = tokenizer.encode(" " + system + "\n\n", add_special_tokens=False)[1:]
        else:
            system_ids = tokenizer.encode(" \n\n" + system, add_special_tokens=False)[1:]
    else:
        system_ids = []
    query_ids = tokenizer.encode(" " + query, add_special_tokens=False)[1:]
    if label is not None:
        label_ids = tokenizer.encode(label, add_special_tokens=False)
        end_tokens = [eop_id]
    else:
        label_ids = []
        end_tokens = []

    # Tokens that can never be truncated: [gMASK], sop, the two templates and eop (if any).
    # If they alone overflow, truncation would pop from an empty list, so fail early.
    fixed_length = 2 + len(ids1) + len(ids2) + len(end_tokens)
    if fixed_length > max_length:
        raise ValueError(f"max_length={max_length} is too small to hold the {fixed_length} "
                         f"special tokens required by the ChatGLM2 template")
    num_tokens_to_remove = fixed_length + len(query_ids) + len(system_ids) + len(label_ids) - max_length

    # Remove overflowing tokens ("longest_first"); system tokens go last.
    warned = False
    for _ in range(max(num_tokens_to_remove, 0)):
        if len(query_ids) + len(system_ids) > len(label_ids) and len(query_ids) > 0:
            query_ids.pop()
        elif len(label_ids) > 0:
            label_ids.pop()
        else:
            if not warned:
                logger.warning("removing system tokens due to tokens overflowing")
                warned = True
            system_ids.pop()
    label_ids += end_tokens
    if num_tokens_to_remove < 0:
        # Right-pad to exactly max_length
        label_ids += [tokenizer.pad_token_id] * -num_tokens_to_remove

    if is_prefix:
        prompt_ids = [gmask_id, sop_id] + ids1 + system_ids + query_ids + ids2
    else:
        prompt_ids = [gmask_id, sop_id] + ids1 + query_ids + system_ids + ids2
    input_ids = prompt_ids + label_ids
    labels = [tokenizer.pad_token_id] * len(prompt_ids) + label_ids
    assert len(input_ids) == len(labels) == max_length
    return input_ids, labels, prompt_ids


class DataCollatorReward:
    """Stack reward-model samples into a batch of "chosen" tensors.

    Each sample is a mapping with an ``input_ids`` tensor and, optionally, an
    ``attention_mask`` tensor. Whether masks are stacked is decided from the first
    sample only; when it has none, ``chosen_attention_mask`` is ``None``.
    """

    def __call__(self, data):
        mask_present = 'attention_mask' in data[0]
        chosen_ids = torch.stack([sample['input_ids'] for sample in data])
        chosen_mask = None
        if mask_present:
            chosen_mask = torch.stack([sample['attention_mask'] for sample in data])
        return {
            "chosen_input_ids": chosen_ids,
            "chosen_attention_mask": chosen_mask,
        }


class DataCollatorRLHF:
    """Collate RLHF prompts into a left-padded batch of width ``max(max_token_len, longest)``.

    Each sample is indexed as ``sample[0]`` (prompt token ids), ``sample[1]`` (attention mask)
    and ``sample[-1]`` (pad token id). Prompts are expected to be stored *reversed*; this
    collator flips them back with ``flip(1)``. NOTE(review): confirm the dataset stores prompts
    flipped.

    Padding is always appended on the right of the reversed sequences, both by
    ``pad_sequence`` and by the extra ``F.pad`` up to ``max_token_len``. After the final
    flip, all padding therefore sits on the left of every prompt. Previously the extra padding
    was added on the left before flipping, which left it on the right of the prompt while the
    ``pad_sequence`` padding ended up on the left.
    """

    def __init__(self, max_token_len, inference_tp_size):
        self.max_token_len = max_token_len
        # Stored for interface compatibility; not used by __call__.
        self.inference_tp_size = inference_tp_size

    def __call__(self, data):
        batch = {}
        # The pad id is taken from the last element of the last sample.
        pad_token_id = data[-1][-1]

        prompt = pad_sequence([f[0] for f in data],
                              padding_value=pad_token_id,
                              batch_first=True)
        prompt_mask = pad_sequence([f[1] for f in data],
                                   padding_value=0,
                                   batch_first=True)

        # Pad the (still reversed) batch on the right up to max_token_len so the
        # flip below moves every pad token to the left.
        pad_length = self.max_token_len - prompt.size(-1)
        if pad_length > 0:
            prompt = F.pad(prompt,
                           pad=(0, pad_length),
                           mode='constant',
                           value=pad_token_id)
            prompt_mask = F.pad(prompt_mask,
                                pad=(0, pad_length),
                                mode='constant',
                                value=0)
        batch["prompt"] = prompt.flip(1)
        batch["prompt_att_mask"] = prompt_mask.flip(1)
        return batch


class PretrainDataset(Dataset):
    """Dataset for (continued) pre-training read from a JSON-lines file.

    Each line is a JSON object with a non-empty ``prompt`` and an optional ``label``;
    lines with an empty prompt are discarded.

    * ``concat_samples=False``: one item per line, tokenized in ``__getitem__`` with
      model-specific formatting chosen from ``model_name_or_path``.
    * ``concat_samples=True``: lines are tokenized up front and packed into chunks of exactly
      ``args.max_length`` tokens. ``__getitem__`` then builds a mask that blocks attention
      across packed samples. The final, partially-filled chunk is dropped.
    """

    def __init__(self, args, filename, tokenizer, concat_samples=True):
        self.args = args
        self.tokenizer = tokenizer
        self.concat_samples = concat_samples
        # `args` without `model_name_or_path` (presumably RLHF configs) fall back to `actor_model_path`
        self.model_name_or_path = args.model_name_or_path if hasattr(args,
                                                                     "model_name_or_path") else args.actor_model_path

        self.post_list = self.load_dataset(filename)
        # Log up to 5 samples; a fixed range(5) raised IndexError on smaller datasets
        for k in range(min(5, len(self.post_list))):
            print_rank_0(f"PretrainDataset sample-{k}\n: {self.post_list[k]}")

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Return the model inputs for item ``idx``.

        Non-packed items are tokenized here. GLM additionally returns ``position_ids``;
        pangu returns the tokenizer tensors without the ``[0]`` indexing used by the other
        branches. Packed items return ``input_ids``/``labels`` (a Python list) plus a
        ``(max_length, max_length)`` additive mask, stored under ``full_attention_mask`` for
        ChatGLM2 and under ``attention_mask`` otherwise.
        """
        data = self.post_list[idx]
        model_name = self.model_name_or_path.lower()
        if not self.concat_samples:
            prompt = data['prompt']
            label = data.get('label', None)
            if "glm" in model_name and "chatglm" not in model_name:
                # GLM: split max_length between prompt (+ [MASK]) and label, trimming the longer
                # side one token at a time. The +1 on the label presumably reserves room for a
                # token GLM adds to targets — TODO confirm.
                encoded_prompt = self.tokenizer(prompt, self.tokenizer.mask_token)
                prompt_length = len(encoded_prompt['input_ids'])
                label_length = len(self.tokenizer.tokenize(label)) + 1
                if prompt_length + label_length > self.args.max_length:
                    num_tokens_to_remove = prompt_length + label_length - self.args.max_length
                    for _ in range(num_tokens_to_remove):
                        if prompt_length > label_length:
                            prompt_length -= 1
                        else:
                            label_length -= 1
                else:
                    # Give all remaining room to the label so the total is exactly max_length
                    label_length = self.args.max_length - prompt_length
                assert prompt_length > 0
                assert label_length > 0
                assert prompt_length + label_length == self.args.max_length
                encoded_dict = self.tokenizer(prompt, self.tokenizer.mask_token,
                                              max_length=prompt_length,
                                              truncation="only_first",
                                              return_tensors="pt",
                                              return_attention_mask=True,
                                              return_token_type_ids=False)
                encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict, targets=label,
                                                                          max_gen_length=label_length, padding=True)
                return {
                    "input_ids": encoded_dict['input_ids'][0],
                    "position_ids": encoded_dict['position_ids'][0],
                    "attention_mask": encoded_dict['attention_mask'][0],
                    "labels": encoded_dict['labels'][0],
                }
            else:
                if "chatglm2" in model_name:
                    prompt = f"[Round {1}]\n\n问:{prompt}\n\n答:"
                elif "chatglm" in model_name:
                    prompt = f"[Round {0}]\n问:{prompt}\n答:"
                elif "vicuna" in model_name:
                    # Prompt and label are trained as one text; tolerate a missing label
                    if label is not None:
                        prompt += "\n\n" + label
                    label = None
                else:
                    label = None
                encoded_dict = self.tokenizer(prompt, label,
                                              max_length=self.args.max_length,
                                              truncation="longest_first",
                                              padding="max_length",
                                              return_token_type_ids=False,
                                              return_tensors="pt", )
                if "pangu" in model_name:
                    return {
                        "input_ids": encoded_dict['input_ids'],
                        "attention_mask": encoded_dict['attention_mask'],
                        "labels": encoded_dict['input_ids'],
                    }
                else:
                    result = {
                        "input_ids": encoded_dict['input_ids'][0],
                        "labels": encoded_dict['input_ids'][0],
                    }
                    if 'attention_mask' in encoded_dict:
                        result["attention_mask"] = encoded_dict['attention_mask'][0]
                    return result
        else:
            eos_ids = data['eos_ids']
            input_ids = data['input_ids']
            # Additive mask: everything starts masked with float16's minimum, then each packed
            # sample's [eos_ids[i], eos_ids[i + 1]) diagonal block is filled from
            # _prepare_decoder_attention_mask, so tokens only attend within their own sample.
            combined_attention_mask = torch.full((self.args.max_length, self.args.max_length),
                                                 torch.tensor(torch.finfo(torch.float16).min))
            for i in range(len(eos_ids) - 1):
                attention_mask = torch.ones((1, eos_ids[i + 1] - eos_ids[i]), dtype=torch.long)
                attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape,
                                                                 input_embeds=torch.ones(1, dtype=torch.float16,
                                                                                         device="cpu"),
                                                                 past_key_values_length=0)
                logger.debug(f"{i}-th sample, shape: {attention_mask.shape}, attention_mask: {attention_mask}")
                combined_attention_mask[eos_ids[i]:eos_ids[i + 1], eos_ids[i]:eos_ids[i + 1]] = attention_mask
            logger.debug(f"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}")
            if "chatglm2" in model_name:
                return {
                    "input_ids": input_ids,
                    "labels": input_ids,
                    "full_attention_mask": combined_attention_mask,
                }
            else:
                return {
                    "input_ids": input_ids,
                    "labels": input_ids,
                    "attention_mask": combined_attention_mask,
                }

    def load_dataset(self, filename):
        """Read ``filename`` (JSON lines) into a list of samples.

        Without packing, each sample is ``{"prompt": str, "label": str | None}``. With packing,
        each sample is a full chunk ``{"input_ids": [...], "eos_ids": [0, e1, ..., max_length]}``
        where ``eos_ids`` are the boundaries of the samples packed into it. The trailing
        partially-filled chunk is not emitted.
        """
        discard = 0
        datasets = []
        with open(filename, "r", encoding="utf-8") as f:
            data = []
            eos_ids = [0]
            length = 0
            for line in tqdm(f, desc=f"Loading {os.path.basename(filename)}"):
                item = json.loads(line)
                prompt = str(item['prompt'])
                label = item.get('label', None)
                if len(prompt) <= 0:
                    discard += 1
                    continue
                if not self.concat_samples:
                    datasets.append({"prompt": prompt, "label": label})
                    continue

                model_name = self.model_name_or_path.lower()
                if "chatglm2" not in model_name:
                    assert "glm" not in model_name, \
                        "Concatenating samples for GLM or ChatGLM not implemented yet"
                if "chatglm2" in model_name:
                    prompt = f"[Round {1}]\n\n问:{prompt}\n\n答:"
                else:
                    prompt = prompt if label is None else "\n\n".join((prompt, label))
                    label = None
                # Never encode more tokens than the current chunk still has room for
                token_ids = self.tokenizer.encode(prompt, label,
                                                  max_length=self.args.max_length - length,
                                                  truncation="longest_first")
                if length + len(token_ids) < self.args.max_length:
                    data.extend(token_ids)
                    length += len(token_ids)
                    eos_ids.append(length)
                else:
                    # Chunk is full: emit it and start a new one
                    data.extend(token_ids[:(self.args.max_length - length)])
                    eos_ids.append(self.args.max_length)
                    datasets.append({"input_ids": data, "eos_ids": eos_ids})
                    data = []
                    eos_ids = [0]
                    length = 0
        print_rank_0(
            f"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}")

        return datasets


class SFTDataset(Dataset):
    """Supervised fine-tuning dataset read from a JSON-lines file.

    Each line is a JSON object with ``prompt`` and ``answers`` (only ``answers[0]["answer"]``
    is used) plus optional ``prefix``, ``system`` and ``data_type`` keys. Lines whose
    ``data_type`` is not ``"human_generated"`` are skipped; lines with an empty prompt or
    answer are counted as discarded.

    * ``concat_samples=False``: one item per line, tokenized in ``__getitem__`` with
      model-specific formatting chosen from ``model_name_or_path``.
    * ``concat_samples=True``: lines are tokenized up front and packed into chunks of exactly
      ``args.max_length`` tokens; the final, partially-filled chunk is dropped.
    """

    def __init__(self, args, filename, tokenizer, concat_samples=True):
        self.args = args
        self.tokenizer = tokenizer
        self.concat_samples = concat_samples
        # `args` without `model_name_or_path` (presumably RLHF configs) fall back to `actor_model_path`
        self.model_name_or_path = args.model_name_or_path if hasattr(args,
                                                                     "model_name_or_path") else args.actor_model_path

        self.post_list = self.load_dataset(filename)
        # Log up to 5 samples; a fixed range(5) raised IndexError on smaller datasets
        for k in range(min(5, len(self.post_list))):
            print_rank_0(f"SFTDataset sample-{k}\n: {self.post_list[k]}")

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Return the model inputs for item ``idx``.

        Non-packed items are tokenized here with the formatting of the detected model family.
        ChatGLM3/ChatGLM2 delegate to ``chatglm3_encode``/``chatglm2_encode``. Packed items
        return ``input_ids``/``labels`` (a Python list) plus a ``(max_length, max_length)``
        additive mask, stored under ``full_attention_mask`` for ChatGLM2 and under
        ``attention_mask`` otherwise.
        """
        data = self.post_list[idx]
        model_name = self.model_name_or_path.lower()
        if not self.concat_samples:
            prompt = data['prompt']
            label = data['label']
            prefix = data['prefix']
            system = data['system']
            if "glm" in model_name and "chatglm" not in model_name:
                # GLM: split max_length between prompt (+ prefix + [MASK]) and label, trimming
                # the longer side one token at a time. The +1 on the label presumably reserves
                # room for a token GLM adds to targets — TODO confirm.
                encoded_prompt = self.tokenizer(prompt, prefix + self.tokenizer.mask_token)
                prompt_length = len(encoded_prompt['input_ids'])
                label_length = len(self.tokenizer.tokenize(label)) + 1
                if prompt_length + label_length > self.args.max_length:
                    num_tokens_to_remove = prompt_length + label_length - self.args.max_length
                    for _ in range(num_tokens_to_remove):
                        if prompt_length > label_length:
                            prompt_length -= 1
                        else:
                            label_length -= 1
                else:
                    # Give all remaining room to the label so the total is exactly max_length
                    label_length = self.args.max_length - prompt_length
                assert prompt_length > 0
                assert label_length > 0
                assert prompt_length + label_length == self.args.max_length
                encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,
                                              max_length=prompt_length,
                                              truncation="only_first",
                                              return_tensors="pt",
                                              return_attention_mask=True,
                                              return_token_type_ids=False)
                encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict, targets=label,
                                                                          max_gen_length=label_length, padding=True)
                return {
                    "input_ids": encoded_dict['input_ids'][0],
                    "position_ids": encoded_dict['position_ids'][0],
                    "attention_mask": encoded_dict['attention_mask'][0],
                    "labels": encoded_dict['labels'][0],
                }
            elif "pangu" in model_name:
                label = prefix + label
                encoded_dict = self.tokenizer(prompt, label,
                                              max_length=self.args.max_length,
                                              truncation="longest_first",
                                              padding="max_length",
                                              return_token_type_ids=False,
                                              return_tensors="pt", )
                return {
                    "input_ids": encoded_dict['input_ids'],
                    "attention_mask": encoded_dict['attention_mask'],
                    "labels": encoded_dict['input_ids'],
                }
            elif "chatglm3" in model_name:
                input_ids, labels, _ = chatglm3_encode(self.tokenizer, prompt, label, system, self.args.max_length)
                return {
                    "input_ids": torch.tensor(input_ids, dtype=torch.long),
                    "labels": torch.tensor(labels, dtype=torch.long),
                }
            elif "chatglm2" in model_name:
                input_ids, labels, _ = chatglm2_encode(self.tokenizer, prompt, label, system, self.args.max_length)
                return {
                    "input_ids": torch.tensor(input_ids, dtype=torch.long),
                    "labels": torch.tensor(labels, dtype=torch.long),
                }
            elif "chatglm" in model_name:
                prompt = f"[Round {0}]\n问:{prompt}\n答:"
                encoded_dict = self.tokenizer(prompt, label,
                                              max_length=self.args.max_length,
                                              truncation="longest_first",
                                              padding="max_length",
                                              return_token_type_ids=False,
                                              return_tensors="pt", )
                return {
                    "input_ids": encoded_dict['input_ids'][0],
                    "attention_mask": encoded_dict['attention_mask'][0],
                    "labels": encoded_dict['input_ids'][0],
                }
            else:
                encoded_dict = self.tokenizer(prompt, label,
                                              max_length=self.args.max_length,
                                              truncation="longest_first",
                                              padding="max_length",
                                              return_token_type_ids=False,
                                              return_tensors="pt", )
                result = {
                    "input_ids": encoded_dict['input_ids'][0],
                    "labels": encoded_dict['input_ids'][0],
                }
                if 'attention_mask' in encoded_dict:
                    result["attention_mask"] = encoded_dict['attention_mask'][0]
                return result
        else:
            eos_ids = data['eos_ids']
            input_ids = data['input_ids']
            # Additive mask: everything starts masked with float16's minimum, then each packed
            # sample's [eos_ids[i], eos_ids[i + 1]) diagonal block is filled from
            # _prepare_decoder_attention_mask, so tokens only attend within their own sample.
            combined_attention_mask = torch.full((self.args.max_length, self.args.max_length),
                                                 torch.tensor(torch.finfo(torch.float16).min))
            for i in range(len(eos_ids) - 1):
                attention_mask = torch.ones((1, eos_ids[i + 1] - eos_ids[i]), dtype=torch.long)
                attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape,
                                                                 input_embeds=torch.ones(1, dtype=torch.float16,
                                                                                         device="cpu"),
                                                                 past_key_values_length=0)
                logger.debug(f"{i}-th sample, shape: {attention_mask.shape}, attention_mask: {attention_mask}")
                combined_attention_mask[eos_ids[i]:eos_ids[i + 1], eos_ids[i]:eos_ids[i + 1]] = attention_mask
            logger.debug(f"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}")
            if "chatglm2" in model_name:
                return {
                    "input_ids": input_ids,
                    "labels": input_ids,
                    "full_attention_mask": combined_attention_mask,
                }
            else:
                return {
                    "input_ids": input_ids,
                    "labels": input_ids,
                    "attention_mask": combined_attention_mask,
                }

    def load_dataset(self, filename):
        """Read ``filename`` (JSON lines) into a list of samples.

        Without packing, each sample is ``{"prompt", "label", "prefix", "system"}``. With packing,
        each sample is a full chunk ``{"input_ids": [...], "eos_ids": [0, e1, ..., max_length]}``
        where ``eos_ids`` are the boundaries of the samples packed into it. The trailing
        partially-filled chunk is not emitted.
        """
        discard = 0
        datasets = []
        with open(filename, "r", encoding="utf-8") as f:
            data = []
            eos_ids = [0]
            length = 0
            for line in tqdm(f, desc=f"Loading {os.path.basename(filename)}"):
                item = json.loads(line)
                data_type = item.get('data_type', "human_generated")
                if data_type != "human_generated":
                    continue
                prompt = str(item['prompt'])
                label = str(item['answers'][0]['answer'])
                prefix = item.get('prefix', "")
                system = item.get('system', "")
                if len(prompt) <= 0 or len(label) <= 0:
                    discard += 1
                    continue

                if not self.concat_samples:
                    datasets.append({"prompt": prompt, "label": label, "prefix": prefix, "system": system})
                    continue

                # Packing. Tokenization previously sat inside the chatglm2-only branch, so every
                # other model produced an empty dataset.
                model_name = self.model_name_or_path.lower()
                if "chatglm2" not in model_name:
                    assert "glm" not in model_name, \
                        "Concatenating samples for GLM or ChatGLM not implemented yet"
                if "chatglm2" in model_name:
                    prompt = f"[Round {1}]\n\n问:{prompt}\n\n答:"
                else:
                    prompt = prompt if label is None else "\n\n".join((prompt, label))
                    label = None
                # Never encode more tokens than the current chunk still has room for
                token_ids = self.tokenizer.encode(prompt, label,
                                                  max_length=self.args.max_length - length,
                                                  truncation="longest_first")
                if length + len(token_ids) < self.args.max_length:
                    data.extend(token_ids)
                    length += len(token_ids)
                    eos_ids.append(length)
                else:
                    # Chunk is full: emit it and start a new one
                    data.extend(token_ids[:(self.args.max_length - length)])
                    eos_ids.append(self.args.max_length)
                    datasets.append({"input_ids": data, "eos_ids": eos_ids})
                    data = []
                    eos_ids = [0]
                    length = 0

        print_rank_0(
            f"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}")

        return datasets


class PairwiseDataset(Dataset):
    """Pairwise preference dataset for reward-model training.

    Each line of the input file is a JSON object with a ``prompt``, a list of
    ``answers`` (each ``{"answer": ..., "score": ...}``) and optional ``prefix`` /
    ``system`` strings. Adjacent answers with different scores become
    ``(chosen, rejected)`` pairs. ``__getitem__`` tokenizes a pair according to the
    model family detected (by substring) in ``args.model_name_or_path``.
    """

    def __init__(self, args, filename, tokenizer):
        self.pairs = self.load_dataset(filename)
        self.args = args
        self.tokenizer = tokenizer

        # Log a few samples for sanity checking; guard against files with fewer than 5 pairs.
        for k in range(min(5, len(self.pairs))):
            print_rank_0(f"PairwiseDataset sample-{k}\n: {self.pairs[k]}")

    def __len__(self):
        return len(self.pairs)

    def _glm_encode(self, prompt, prefix, answer, prompt_length):
        """Encode ``prompt`` + ``answer`` for GLM models within ``args.max_length``.

        Args:
            prompt: raw prompt text.
            prefix: text placed before the mask token in the second segment.
            answer: target answer text passed to ``build_inputs_for_generation``.
            prompt_length: estimated prompt length in tokens (before truncation).

        Returns:
            The output of ``tokenizer.build_inputs_for_generation``.
        """
        max_length = self.args.max_length
        # NOTE(review): the +1 presumably reserves room for an end-of-answer token — confirm.
        answer_length = len(self.tokenizer.tokenize(answer)) + 1
        if prompt_length + answer_length > max_length:
            # Trim the overflow from whichever of prompt/answer is longer.
            overflow = prompt_length + answer_length - max_length
            if prompt_length >= answer_length:
                prompt_length -= overflow
            else:
                answer_length -= overflow
        else:
            # Give the answer all of the remaining budget.
            answer_length = max_length - prompt_length
        encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,
                                      max_length=prompt_length,
                                      truncation="only_first",
                                      return_tensors="pt",
                                      return_token_type_ids=False)
        return self.tokenizer.build_inputs_for_generation(encoded_dict,
                                                          targets=answer,
                                                          max_gen_length=answer_length,
                                                          padding=True)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th pair for the configured model family.

        Raises:
            ValueError: if ``args.model_name_or_path`` matches no supported family.
        """
        pair = self.pairs[idx]

        prompt = pair["prompt"]
        chosen_answer = pair["chosen_answer"]
        rejected_answer = pair["rejected_answer"]
        prefix = pair['prefix']
        system = pair['system']
        model_name = self.args.model_name_or_path.lower()
        if "pangu" in model_name:
            chosen_encodings_dict = self.tokenizer(prompt, prefix + chosen_answer, max_length=self.args.max_length,
                                                   truncation="longest_first", padding="max_length",
                                                   return_tensors="pt",
                                                   return_token_type_ids=False)
            rejected_encodings_dict = self.tokenizer(prompt, prefix + rejected_answer, max_length=self.args.max_length,
                                                     truncation="longest_first", padding="max_length",
                                                     return_tensors="pt",
                                                     return_token_type_ids=False)
            # NOTE(review): unlike the other branches these are not indexed with [0], so they keep the
            # leading dim produced by return_tensors="pt" — confirm the collator expects that.
            return {
                "chosen_input_ids": chosen_encodings_dict["input_ids"],
                "chosen_attention_mask": chosen_encodings_dict["attention_mask"],
                "rejected_input_ids": rejected_encodings_dict["input_ids"],
                "rejected_attention_mask": rejected_encodings_dict["attention_mask"],
                "labels": rejected_encodings_dict["input_ids"],
            }
        elif "chatglm3" in model_name:
            chosen_input_ids, _, _ = chatglm3_encode(self.tokenizer, prompt, chosen_answer, system,
                                                     self.args.max_length)
            # Labels come from the rejected encoding, as in the other branches.
            rejected_input_ids, labels, _ = chatglm3_encode(self.tokenizer, prompt, rejected_answer, system,
                                                            self.args.max_length)
            return {
                "chosen_input_ids": torch.tensor(chosen_input_ids, dtype=torch.long),
                "rejected_input_ids": torch.tensor(rejected_input_ids, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long)
            }
        elif "chatglm2" in model_name:
            chosen_input_ids, _, _ = chatglm2_encode(self.tokenizer, prompt, chosen_answer, system,
                                                     self.args.max_length)
            rejected_input_ids, labels, _ = chatglm2_encode(self.tokenizer, prompt, rejected_answer, system,
                                                            self.args.max_length)
            return {
                "chosen_input_ids": torch.tensor(chosen_input_ids, dtype=torch.long),
                "rejected_input_ids": torch.tensor(rejected_input_ids, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long)
            }
        elif "chatglm" in model_name:
            prompt = f"[Round {0}]\n问:{prompt}\n答:"
            chosen_encodings_dict = self.tokenizer(prompt, chosen_answer, max_length=self.args.max_length,
                                                   truncation="longest_first", padding="max_length",
                                                   return_tensors="pt")
            rejected_encodings_dict = self.tokenizer(prompt, rejected_answer, max_length=self.args.max_length,
                                                     truncation="longest_first", padding="max_length",
                                                     return_tensors="pt")
            return {
                "chosen_input_ids": chosen_encodings_dict["input_ids"][0],
                "rejected_input_ids": rejected_encodings_dict["input_ids"][0],
                "labels": rejected_encodings_dict["input_ids"][0],
            }
        elif "glm" in model_name:
            # NOTE(review): the +4 presumably reserves room for special tokens added around the
            # prompt — confirm against the GLM tokenizer.
            prompt_length = len(self.tokenizer.tokenize(prompt + prefix)) + 4
            chosen_encodings_dict = self._glm_encode(prompt, prefix, chosen_answer, prompt_length)
            rejected_encodings_dict = self._glm_encode(prompt, prefix, rejected_answer, prompt_length)
            return {
                "chosen_input_ids": chosen_encodings_dict["input_ids"][0],
                "chosen_attention_mask": chosen_encodings_dict["attention_mask"][0],
                "chosen_position_ids": chosen_encodings_dict["position_ids"][0],
                "rejected_input_ids": rejected_encodings_dict["input_ids"][0],
                "rejected_attention_mask": rejected_encodings_dict["attention_mask"][0],
                "rejected_position_ids": rejected_encodings_dict["position_ids"][0],
                "labels": rejected_encodings_dict["input_ids"][0],
            }
        else:
            raise ValueError(f"Unsupported model name: {self.args.model_name_or_path}")

    @staticmethod
    def load_dataset(filename):
        """Read a JSON-lines file and build ``(chosen, rejected)`` pairs.

        For every adjacent pair ``(answers[i], answers[i + 1])``, the higher-scored
        answer becomes ``chosen_answer`` and the other ``rejected_answer``. A pair is
        discarded when the scores are equal, the prompt or either answer is empty,
        or both answers have identical text.

        Args:
            filename: path to a UTF-8 JSON-lines file.

        Returns:
            A list of dicts with keys ``prompt``, ``prefix``, ``system``,
            ``chosen_answer`` and ``rejected_answer``.
        """
        discard = 0
        pairs = []
        with open(filename, "r", encoding="utf-8") as f:
            for line in tqdm(f, desc=f"Loading {os.path.basename(filename)}"):
                item = json.loads(line)
                prompt = str(item['prompt'])
                answers = item['answers']
                prefix = item.get('prefix', "")
                system = item.get('system', "")
                for i in range(len(answers) - 1):
                    answer_1 = str(answers[i]["answer"])
                    answer_1_score = answers[i]["score"]
                    answer_2 = str(answers[i + 1]["answer"])
                    answer_2_score = answers[i + 1]["score"]
                    # Decide the preference for this pair only. A chosen answer from an earlier
                    # pair must never be matched against a higher-scored rejected answer.
                    if answer_1_score > answer_2_score:
                        chosen_answer, rejected_answer = answer_1, answer_2
                    elif answer_2_score > answer_1_score:
                        chosen_answer, rejected_answer = answer_2, answer_1
                    else:
                        chosen_answer, rejected_answer = None, None
                    if chosen_answer is not None and rejected_answer is not None \
                            and len(prompt) > 0 and len(chosen_answer) > 0 and len(rejected_answer) > 0 \
                            and chosen_answer != rejected_answer:
                        pairs.append({
                            "prompt": prompt,
                            "prefix": prefix,
                            "system": system,
                            "chosen_answer": chosen_answer,
                            "rejected_answer": rejected_answer
                        })
                    else:
                        discard += 1

        print_rank_0(f"Finished loading {os.path.basename(filename)}, # pairs: {len(pairs)}, # discarded: {discard}")

        return pairs


class RLHFDataset(Dataset):
    """Prompt-only dataset that feeds the actor model during RLHF training.

    Only lines whose ``data_type`` is ``"human_generated"`` (the default when the
    key is absent) and whose prompt is non-empty are kept. ``__getitem__`` encodes
    the prompt according to the model family detected (by case-insensitive
    substring) in ``args.actor_model_path``.
    """

    def __init__(self, args, filename, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        assert tokenizer.padding_side == "left", "In RLHF training, need to set padding_side to 'left'"

        self.post_list = self.load_dataset(filename)
        # Log a few samples for sanity checking; guard against files with fewer than 5 samples.
        for k in range(min(5, len(self.post_list))):
            print_rank_0(f"RLHFDataset sample-{k}\n: {self.post_list[k]}")

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Encode the ``idx``-th prompt for generation.

        Raises:
            ValueError: if ``args.actor_model_path`` matches no supported family.
        """
        data = self.post_list[idx]
        prompt = data['prompt']
        prefix = data['prefix']
        system = data['system']
        # Lower-case for matching, consistent with the other datasets in this module.
        model_path = self.args.actor_model_path.lower()
        if "pangu" in model_path:
            encoded_dict = self.tokenizer(prompt, self.tokenizer.sep_token + prefix,
                                          max_length=self.args.max_prompt_length,
                                          # padding="max_length",
                                          truncation="only_first", add_special_tokens=False,
                                          return_tensors="pt", return_token_type_ids=False)
            return {
                "input_ids": encoded_dict['input_ids'][0],
                "attention_mask": encoded_dict['attention_mask'][0],
                # "labels": encoded_dict['input_ids'],
            }
        elif "chatglm" in model_path:
            # NOTE(review): when `system` is empty this still prepends "\n\n" to the prompt;
            # confirm it matches the SFT prompt format before changing it.
            prompt = "\n\n".join((system, prompt))
            # NOTE(review): "chatglm3" paths do not contain "chatglm2", so they use the
            # first-generation ChatGLM template below — confirm this is intended.
            if "chatglm2" in model_path:
                prompt = f"[Round {1}]\n\n问:{prompt}\n\n答:"
            else:
                prompt = f"[Round {0}]\n问:{prompt}\n答:"
            encoded_dict = self.tokenizer(prompt, max_length=self.args.max_prompt_length,
                                          return_tensors="pt", truncation="only_first")
            return {
                "input_ids": encoded_dict['input_ids'][0],
            }
        elif "glm" in model_path:
            # encoded_prompt = self.tokenizer(prompt, prefix + self.tokenizer.mask_token)
            # prompt_length = len(encoded_prompt['input_ids'])
            encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,
                                          max_length=self.args.max_prompt_length,
                                          # padding="max_length",
                                          truncation="only_first",
                                          return_tensors="pt",
                                          return_token_type_ids=False)
            encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict,
                                                                      max_gen_length=self.args.max_gen_length,
                                                                      padding=True)

            return {
                "input_ids": encoded_dict['input_ids'][0],
                "position_ids": encoded_dict['position_ids'][0],
                "generation_attention_mask": encoded_dict['generation_attention_mask'][0],
                # "labels": encoded_dict['labels'][0],
            }
        else:
            # Report the attribute actually used for dispatch (model_name_or_path may not exist here).
            raise ValueError(f"Unsupported model name: {self.args.actor_model_path}")

    @staticmethod
    def load_dataset(filename):
        """Load prompts from a UTF-8 JSON-lines file.

        Lines whose ``data_type`` is not ``"human_generated"`` are skipped silently.
        Lines with an empty prompt are counted as discarded.

        Returns:
            A list of dicts with keys ``prompt``, ``system`` and ``prefix``.
        """
        discard = 0
        datasets = []
        with open(filename, "r", encoding="utf-8") as f:
            for line in tqdm(f, desc=f"Loading {os.path.basename(filename)}"):
                item = json.loads(line)
                data_type = item.get('data_type', "human_generated")
                if data_type != "human_generated":
                    continue
                prompt = str(item['prompt'])
                prefix = item.get('prefix', "")
                system = item.get('system', "")

                if len(prompt) <= 0:
                    discard += 1
                    continue
                datasets.append({"prompt": prompt, "system": system, "prefix": prefix})
        print_rank_0(
            f"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}")

        return datasets


class PPODataset:
    """Buffer that collects ``max_size`` large batches, then splits them into mini-batches.

    A large batch may be a list/tuple of sliceable sequences, a dict whose values are
    sliceable sequences or ``None``, or any other sliceable object. Once ``max_size``
    batches have been added, ``add`` returns all of them cut into chunks of
    ``small_batch_size`` rows and empties the buffer.
    """

    def __init__(self, max_size, small_batch_size):
        self.dataset = []
        self.max_size = max_size
        self.small_batch_size = small_batch_size

    @staticmethod
    def _batch_size(large_batch):
        """Return the number of rows in ``large_batch``.

        For lists/tuples this is the length of the first element. For dicts it is the
        length of the first non-``None`` value (0 if every value is ``None``). Otherwise
        it is ``len(large_batch)``.
        """
        if isinstance(large_batch, (list, tuple)):
            return len(large_batch[0])
        if isinstance(large_batch, dict):
            # Skip None values: they are carried through unsliced and have no length.
            first = next((v for v in large_batch.values() if v is not None), None)
            return 0 if first is None else len(first)
        return len(large_batch)

    def separate(self):
        """Split every buffered batch into ``small_batch_size`` chunks and clear the buffer.

        Returns:
            A flat list of mini-batches, with the same container kind as each source
            batch. Lists and tuples become lists; dict subclasses become plain dicts.
        """
        small_dataset = []
        step = self.small_batch_size
        for large_batch in self.dataset:
            large_size = self._batch_size(large_batch)
            for i in range(0, large_size, step):
                if isinstance(large_batch, (list, tuple)):
                    small_dataset.append([x[i:i + step] for x in large_batch])
                elif isinstance(large_batch, dict):
                    small_dataset.append({
                        k: v[i:i + step] if v is not None else None
                        for k, v in large_batch.items()
                    })
                else:
                    small_dataset.append(large_batch[i:i + step])
        self.free()

        return small_dataset

    def add(self, data):
        """Append one large batch.

        Returns:
            The mini-batches from ``separate()`` once the buffer reaches ``max_size``,
            otherwise ``None``.

        Raises:
            ValueError: if called while the buffer is already full.
        """
        if len(self.dataset) < self.max_size:
            self.dataset.append(data)
            if len(self.dataset) == self.max_size:
                return self.separate()
            else:
                return None
        else:
            raise ValueError(
                "The dataset is full but we did not stop it. There is a bug in the code."
            )

    def free(self):
        """Drop all buffered batches."""
        self.dataset = []


class DPODataset(Dataset):
    def __init__(self, args, filename, tokenizer):
        self.pairs = self.load_dataset(filename)
        self.args = args
        self.tokenizer = tokenizer

        for k in range(5):
            print_rank_0(f"DPODataset sample-{k}\n: {self.pairs[k]}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        index = pair["index"]
        prompt = pair["prompt"]
        chosen_answer = pair["chosen_answer"]
        rejected_answer = pair["rejected_answer"]
        prefix = pair['prefix']
        system = pair['system']
        if "pangu" in self.args.model_name_or_path.lower():
            chosen_encodings_dict = self.tokenizer(prompt, prefix + chosen_answer, max_length=self.args.max_length,
                                                   truncation="longest_first", padding="max_length",
                                                   return_tensors="pt",
                                                   return_token_type_ids=False)
            rejected_encodings_dict = self.tokenizer(prompt, prefix + rejected_answer, max_length=self.args.max_length,
                                                     truncation="longest_first", padding="max_length",
                                                     return_tensors="pt",
                                                     return_token_type_ids=False)
            return {
                "chosen_input_ids": chosen_encodings_dict["input_ids"],
                "chosen_attention_mask": chosen_encodings_dict["attention_mask"],
                "rejected_input_ids": rejected_encodings_dict["input_ids"],
                "rejected_attention_mask": rejected_encodings_dict["attention_mask"],
                "labels": rejected_encodings_dict["input_ids"],
            }
        elif "chatglm3" in self.args.model_name_or_path.lower():
            chosen_input_ids, chosen_labels, _ = chatglm3_encode(self.tokenizer, prompt, chosen_answer, system,
                                                                 self.args.max_length)
            rejected_input_ids, rejected_labels, _ = chatglm3_encode(self.tokenizer, prompt, rejected_answer, system,
                                                                     self.args.max_length)
            return {
                "index": torch.tensor(index, dtype=torch.long),
                "chosen_input_ids": torch.tensor(chosen_input_ids, dtype=torch.long),
                "rejected_input_ids": torch.tensor(rejected_input_ids, dtype=torch.long),
                "chosen_labels": torch.tensor(chosen_labels, dtype=torch.long),
                "rejected_labels": torch.tensor(rejected_labels, dtype=torch.long)
            }
        elif "chatglm2" in self.args.model_name_or_path.lower():
            chosen_input_ids, chosen_labels, _ = chatglm2_encode(self.tokenizer, prompt, chosen_answer, system,
                                                                 self.args.max_length)
            rejected_input_ids, rejected_labels, _ = chatglm2_encode(self.tokenizer, prompt, rejected_answer, system,
                                                                     self.args.max_length)
            return {
                "index": torch.tensor(index, dtype=torch.long),
                "chosen_input_ids": torch.tensor(chosen_input_ids, dtype=torch.long),
                "rejected_input_ids": torch.tensor(rejected_input_ids, dtype=torch.long),
                "chosen_labels": torch.tensor(chosen_labels, dtype=torch.long),
                "rejected_labels": torch.tensor(rejected_labels, dtype=torch.long)
            }
        elif "chatglm" in self.args.model_name_or_path.lower():
            prompt = f"[Round {0}]\n问:{prompt}\n答:"
            chosen_encodings_dict = self.tokenizer(prompt, chosen_answer, max_length=self.args.max_length,
                                                   truncation="longest_first", padding="max_length",
                                                   return_tensors="pt")
            rejected_encodings_dict = self.tokenizer(prompt, rejected_a
Download .txt
gitextract_138ju87y/

├── README.md
├── docs/
│   └── INSTALL_ON_WINDOWS.md
├── examples/
│   ├── data_prepare.sh
│   ├── eval_pretrain.sh
│   ├── eval_reward.sh
│   ├── pred_pretrain.sh
│   ├── pred_reward.sh
│   ├── pretrain.sh
│   ├── train_dpo.sh
│   ├── train_reward.sh
│   ├── train_rlhf.sh
│   └── train_sft.sh
├── notebook/
│   ├── data_prepare.ipynb
│   ├── reward.ipynb
│   ├── rlhf.ipynb
│   ├── sft.ipynb
│   └── test.ipynb
├── requirements.txt
└── src/
    ├── __init__.py
    ├── data/
    │   ├── __init__.py
    │   ├── data.py
    │   ├── data_types.py
    │   └── pipeline.py
    ├── data_prepare.py
    ├── eval_pretrain.py
    ├── models/
    │   ├── __init__.py
    │   ├── loss.py
    │   ├── ppo.py
    │   ├── reward.py
    │   ├── rlhf_engine.py
    │   └── trainer.py
    ├── pretrain.py
    ├── pretrain_wo_trainer.py
    ├── resources/
    │   ├── config/
    │   │   ├── deepspeed/
    │   │   │   ├── Experiment.md
    │   │   │   ├── README.md
    │   │   │   ├── stage-1.json
    │   │   │   ├── stage-2.json
    │   │   │   ├── stage-3-no_trainer.json
    │   │   │   └── stage-3.json
    │   │   ├── log.conf
    │   │   └── ppo_model/
    │   │       ├── ppo_config_glm-350M-chinese.yml
    │   │       ├── ppo_config_glm-small.yml
    │   │       ├── ppo_config_pangu-2.6B.yml
    │   │       └── ppo_config_pangu-350M.yml
    │   ├── dict/
    │   │   └── jieba_custom_dict.txt
    │   ├── eval/
    │   │   ├── ceval/
    │   │   │   ├── subject_mapping.json
    │   │   │   └── submission_example.json
    │   │   └── mmlu/
    │   │       └── subject_mapping.json
    │   └── models/
    │       ├── baichuan-13B-base/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── baichuan-13B-chat/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── baichuan-7B/
    │       │   ├── config.json
    │       │   ├── configuration_baichuan.py
    │       │   ├── modeling_baichuan.py
    │       │   ├── tokenization_baichuan.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm-6B/
    │       │   ├── README.md
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm2-6B/
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── chatglm3-6B/
    │       │   ├── config.json
    │       │   ├── configuration_chatglm.py
    │       │   ├── modeling_chatglm.py
    │       │   ├── quantization.py
    │       │   ├── tokenization_chatglm.py
    │       │   └── tokenizer_config.json
    │       ├── glm-10B-chinese/
    │       │   ├── README.md
    │       │   ├── added_tokens.json
    │       │   ├── config.json
    │       │   ├── configuration_glm.py
    │       │   ├── modeling_glm.py
    │       │   ├── tokenization_glm.py
    │       │   └── tokenizer_config.json
    │       ├── glm-350M-chinese/
    │       │   ├── README.md
    │       │   ├── added_tokens.json
    │       │   ├── config.json
    │       │   ├── configuration_glm.py
    │       │   ├── modeling_glm.py
    │       │   ├── tokenization_glm.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-13B/
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-2.6B/
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       ├── pangu-350M/
    │       │   ├── README-Reward.md
    │       │   ├── README-SFT.md
    │       │   ├── README.md
    │       │   ├── config.json
    │       │   ├── configuration_gptpangu.py
    │       │   ├── modeling_gptpangu.py
    │       │   ├── tokenization_gptpangu.py
    │       │   └── tokenizer_config.json
    │       └── qwen-7B-chat/
    │           ├── config.json
    │           ├── configuration_qwen.py
    │           ├── generation_config.json
    │           ├── modeling_qwen.py
    │           ├── qwen_generation_utils.py
    │           ├── tokenization_qwen.py
    │           └── tokenizer_config.json
    ├── train_dpo.py
    ├── train_reward.py
    ├── train_rlhf.py
    ├── train_rlhf_trlx.py
    ├── train_sft.py
    ├── train_sft_wo_trainer.py
    └── utils/
        ├── __init__.py
        ├── config.py
        ├── file_utils.py
        ├── loading.py
        ├── logger.py
        ├── method_configs.py
        ├── modeling_utils.py
        └── nlp_utils.py
Download .txt
SYMBOL INDEX (1225 symbols across 64 files)

FILE: src/data/data.py
  function chatglm3_encode (line 21) | def chatglm3_encode(tokenizer: PreTrainedTokenizerBase,
  function chatglm2_encode (line 83) | def chatglm2_encode(tokenizer: PreTrainedTokenizerBase,
  class DataCollatorReward (line 147) | class DataCollatorReward:
    method __call__ (line 148) | def __call__(self, data):
  class DataCollatorRLHF (line 160) | class DataCollatorRLHF:
    method __init__ (line 162) | def __init__(self, max_token_len, inference_tp_size):
    method __call__ (line 166) | def __call__(self, data):
  class PretrainDataset (line 197) | class PretrainDataset(Dataset):
    method __init__ (line 198) | def __init__(self, args, filename, tokenizer, concat_samples=True):
    method __len__ (line 209) | def __len__(self):
    method __getitem__ (line 212) | def __getitem__(self, idx):
    method load_dataset (line 306) | def load_dataset(self, filename):
  class SFTDataset (line 351) | class SFTDataset(Dataset):
    method __init__ (line 352) | def __init__(self, args, filename, tokenizer, concat_samples=True):
    method __len__ (line 363) | def __len__(self):
    method __getitem__ (line 366) | def __getitem__(self, idx):
    method load_dataset (line 509) | def load_dataset(self, filename):
  class PairwiseDataset (line 563) | class PairwiseDataset(Dataset):
    method __init__ (line 564) | def __init__(self, args, filename, tokenizer):
    method __len__ (line 572) | def __len__(self):
    method __getitem__ (line 575) | def __getitem__(self, idx):
    method load_dataset (line 683) | def load_dataset(filename):
  class RLHFDataset (line 721) | class RLHFDataset(Dataset):
    method __init__ (line 722) | def __init__(self, args, filename, tokenizer):
    method __len__ (line 731) | def __len__(self):
    method __getitem__ (line 734) | def __getitem__(self, idx):
    method load_dataset (line 781) | def load_dataset(filename):
  class PPODataset (line 804) | class PPODataset:
    method __init__ (line 805) | def __init__(self, max_size, small_batch_size):
    method separate (line 810) | def separate(self):
    method add (line 834) | def add(self, data):
    method free (line 846) | def free(self):
  class DPODataset (line 850) | class DPODataset(Dataset):
    method __init__ (line 851) | def __init__(self, args, filename, tokenizer):
    method __len__ (line 859) | def __len__(self):
    method __getitem__ (line 862) | def __getitem__(self, idx):
    method load_dataset (line 928) | def load_dataset(filename):
  class OCNLIDataset (line 969) | class OCNLIDataset(Dataset):
    method __init__ (line 970) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 983) | def __len__(self):
    method __getitem__ (line 986) | def __getitem__(self, idx):
    method load_dataset (line 1017) | def load_dataset(self, filename):
  class CMNLIDataset (line 1040) | class CMNLIDataset(Dataset):
    method __init__ (line 1041) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1054) | def __len__(self):
    method __getitem__ (line 1057) | def __getitem__(self, idx):
    method load_dataset (line 1090) | def load_dataset(self, filename):
  class CHIDDataset (line 1113) | class CHIDDataset(Dataset):
    method __init__ (line 1114) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1127) | def __len__(self):
    method __getitem__ (line 1130) | def __getitem__(self, idx):
    method load_dataset (line 1165) | def load_dataset(self, filename):
    method load_idiom_dict (line 1186) | def load_idiom_dict(self):
  class CMRCDataset (line 1195) | class CMRCDataset(Dataset):
    method __init__ (line 1196) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1208) | def __len__(self):
    method __getitem__ (line 1211) | def __getitem__(self, idx):
    method load_dataset (line 1242) | def load_dataset(self, filename):
  class CLUEWSCDataset (line 1267) | class CLUEWSCDataset(Dataset):
    method __init__ (line 1268) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1281) | def __len__(self):
    method __getitem__ (line 1284) | def __getitem__(self, idx):
    method load_dataset (line 1317) | def load_dataset(self, filename):
  class C3Dataset (line 1338) | class C3Dataset(Dataset):
    method __init__ (line 1339) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1351) | def __len__(self):
    method __getitem__ (line 1354) | def __getitem__(self, idx):
    method load_dataset (line 1389) | def load_dataset(self, filename):
  class AFQMCDataset (line 1412) | class AFQMCDataset(Dataset):
    method __init__ (line 1413) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1426) | def __len__(self):
    method __getitem__ (line 1429) | def __getitem__(self, idx):
    method load_dataset (line 1462) | def load_dataset(self, filename):
  class CSLDataset (line 1482) | class CSLDataset(Dataset):
    method __init__ (line 1483) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1496) | def __len__(self):
    method __getitem__ (line 1499) | def __getitem__(self, idx):
    method load_dataset (line 1532) | def load_dataset(self, filename):
  class IFLYTEKDataset (line 1552) | class IFLYTEKDataset(Dataset):
    method __init__ (line 1553) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1587) | def __len__(self):
    method __getitem__ (line 1590) | def __getitem__(self, idx):
    method load_dataset (line 1625) | def load_dataset(self, filename):
  class TNEWSDataset (line 1648) | class TNEWSDataset(Dataset):
    method __init__ (line 1649) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1676) | def __len__(self):
    method __getitem__ (line 1679) | def __getitem__(self, idx):
    method load_dataset (line 1714) | def load_dataset(self, filename):
  class CEvalDataset (line 1737) | class CEvalDataset(Dataset):
    method __init__ (line 1738) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1754) | def __len__(self):
    method format_example (line 1757) | def format_example(self, line, include_answer=True, cot=False):
    method __getitem__ (line 1785) | def __getitem__(self, idx):
    method load_dataset (line 1863) | def load_dataset(self, filename, return_format="list"):
  class MMLUDataset (line 1887) | class MMLUDataset(Dataset):
    method __init__ (line 1888) | def __init__(self, args, eval_filename, tokenizer, train_filename=None):
    method __len__ (line 1904) | def __len__(self):
    method format_example (line 1907) | def format_example(self, line, include_answer=True):
    method __getitem__ (line 1926) | def __getitem__(self, idx):
    method load_dataset (line 2003) | def load_dataset(self, filename, return_format="list"):

FILE: src/data/data_types.py
  class PromptElement (line 9) | class PromptElement:
  class PromptBatch (line 25) | class PromptBatch:
  class AccelerateRLElement (line 41) | class AccelerateRLElement:
  class AccelerateRLBatchElement (line 57) | class AccelerateRLBatchElement:
  class PPORLElement (line 73) | class PPORLElement:
  class PPORLBatch (line 107) | class PPORLBatch:

FILE: src/data/pipeline.py
  class GeneralElement (line 25) | class GeneralElement:
  class RLElement (line 34) | class RLElement:
  class BatchElement (line 45) | class BatchElement:
  class GLMDataCollator (line 55) | class GLMDataCollator:
    method __call__ (line 63) | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
  function register_datapipeline (line 102) | def register_datapipeline(name):
  class BasePipeline (line 125) | class BasePipeline(Dataset):
    method __init__ (line 126) | def __init__(self, path: str = "dataset"):
    method __getitem__ (line 130) | def __getitem__(self, index: int) -> GeneralElement:
    method __len__ (line 134) | def __len__(self) -> int:
    method create_loader (line 138) | def create_loader(
  class BaseRolloutStore (line 153) | class BaseRolloutStore(Dataset):
    method __init__ (line 154) | def __init__(self, capacity=-1):
    method push (line 159) | def push(self, exps: Iterable[Any]):
    method __getitem__ (line 165) | def __getitem__(self, index: int) -> RLElement:
    method __len__ (line 168) | def __len__(self) -> int:
    method create_loader (line 172) | def create_loader(
  class PanguPipeline (line 189) | class PanguPipeline(BasePipeline):
    method __init__ (line 190) | def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: ...
    method __len__ (line 199) | def __len__(self):
    method __getitem__ (line 202) | def __getitem__(self, idx):
    method create_loader (line 219) | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
  class GLMPipeline (line 224) | class GLMPipeline(BasePipeline):
    method __init__ (line 225) | def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: ...
    method __len__ (line 235) | def __len__(self):
    method __getitem__ (line 238) | def __getitem__(self, idx):
    method create_loader (line 257) | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
  class ChatGLMPipeline (line 263) | class ChatGLMPipeline(BasePipeline):
    method __init__ (line 264) | def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: ...
    method __len__ (line 273) | def __len__(self):
    method __getitem__ (line 276) | def __getitem__(self, idx):
    method create_loader (line 287) | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
  class PPORolloutStorage (line 291) | class PPORolloutStorage(BaseRolloutStore):
    method __init__ (line 296) | def __init__(self, pad_token_id):
    method push (line 302) | def push(self, exps: Iterable[PPORLElement]):
    method clear_history (line 305) | def clear_history(self):
    method export_history (line 308) | def export_history(self, location: str):
    method __getitem__ (line 320) | def __getitem__(self, index: int) -> PPORLElement:
    method __len__ (line 323) | def __len__(self) -> int:
    method create_loader (line 326) | def create_loader(

FILE: src/data_prepare.py
  function weibo_summary_comment (line 24) | def weibo_summary_comment(args, tokenizer):
  function couplets (line 61) | def couplets(args, tokenizer):
  function zhidao (line 117) | def zhidao(args, tokenizer):
  function chinese_classical (line 159) | def chinese_classical(args, tokenizer):
  function chinese_poetry (line 221) | def chinese_poetry(args, tokenizer):
  function baike_qa_2019 (line 372) | def baike_qa_2019(args, tokenizer):
  function get_parser (line 404) | def get_parser():
  function main (line 418) | def main():

FILE: src/eval_pretrain.py
  function get_parser (line 60) | def get_parser():
  function extract_cot_answer (line 101) | def extract_cot_answer(line, response):
  function main (line 106) | def main():

FILE: src/models/loss.py
  class PairWiseLoss (line 6) | class PairWiseLoss(nn.Module):
    method forward (line 11) | def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Te...

FILE: src/models/ppo.py
  class PreTrainedModelWrapper (line 51) | class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
    method __init__ (line 76) | def __init__(self, base_model: Optional[transformers.PreTrainedModel] ...
    method _split_kwargs (line 83) | def _split_kwargs(cls, kwargs: Dict[str, Any]):
    method from_config (line 97) | def from_config(cls, config: transformers.PretrainedConfig, **kwargs):
    method from_pretrained (line 118) | def from_pretrained(  # noqa: max-complexity
    method save_pretrained (line 210) | def save_pretrained(self, *args, **kwargs):
    method state_dict (line 231) | def state_dict(self, *args, **kwargs):
    method post_init (line 235) | def post_init(self, *args, **kwargs):
    method get_compatible_forward_kwargs (line 242) | def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
  class AdaptiveKLController (line 252) | class AdaptiveKLController:
    method __init__ (line 258) | def __init__(self, init_kl_coef: float, target: float, horizon: int):
    method update (line 263) | def update(self, current: float, n_steps: int):
  class FixedKLController (line 273) | class FixedKLController:
    method __init__ (line 276) | def __init__(self, kl_coef):
    method update (line 279) | def update(self, current: float, n_steps: int):
  class CausalLMOutputWithValue (line 289) | class CausalLMOutputWithValue(ModelOutput):
  class AutoModelForCausalLMWithValueHead (line 299) | class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    method __init__ (line 308) | def __init__(
    method forward (line 316) | def forward(
    method generate (line 353) | def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTe...
    method state_dict (line 356) | def state_dict(self, *args, **kwargs):
    method post_init (line 367) | def post_init(self, state_dict):
  class AutoModelForCausalLMWithHydraValueHead (line 381) | class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithVal...
    method __init__ (line 385) | def __init__(
    method forward_hydra (line 414) | def forward_hydra(
    method from_pretrained (line 457) | def from_pretrained(  # noqa: max-complexity
  class ModelBranch (line 564) | class ModelBranch(transformers.PreTrainedModel):
    method __init__ (line 569) | def __init__(
  class GPTModelBranch (line 599) | class GPTModelBranch(ModelBranch):
    method forward (line 600) | def forward(  # noqa: max-complexity
  class OPTModelBranch (line 733) | class OPTModelBranch(ModelBranch):
    method forward (line 734) | def forward(  # noqa: max-complexity
  class BloomModelBranch (line 853) | class BloomModelBranch(ModelBranch):
    method forward (line 854) | def forward(  # noqa: max-complexity
  class Seq2SeqLMOutputWithValue (line 971) | class Seq2SeqLMOutputWithValue(ModelOutput):
  class AutoModelForSeq2SeqLMWithValueHead (line 984) | class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
    method __init__ (line 993) | def __init__(
    method forward (line 1001) | def forward(
    method generate (line 1053) | def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTe...
    method state_dict (line 1056) | def state_dict(self, *args, **kwargs):
    method post_init (line 1067) | def post_init(self, state_dict):
  class AutoModelForSeq2SeqLMWithHydraValueHead (line 1081) | class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithV...
    method __init__ (line 1085) | def __init__(
    method forward_hydra (line 1103) | def forward_hydra(
    method from_pretrained (line 1161) | def from_pretrained(  # noqa: max-complexity
  class T5Branch (line 1268) | class T5Branch(ModelBranch):
    method __init__ (line 1271) | def __init__(
    method forward (line 1281) | def forward(  # noqa: max-complexity
  function hf_get_branch_class (line 1381) | def hf_get_branch_class(

FILE: src/models/reward.py
  class RewardModel (line 7) | class RewardModel(PreTrainedModel):
    method __init__ (line 10) | def __init__(self, config, model, tokenizer):
    method gradient_checkpointing_enable (line 24) | def gradient_checkpointing_enable(self):
    method gradient_checkpointing_disable (line 27) | def gradient_checkpointing_disable(self):
    method _set_gradient_checkpointing (line 30) | def _set_gradient_checkpointing(self, module, value=False):
    method reward (line 34) | def reward(
    method forward (line 71) | def forward(

FILE: src/models/rlhf_engine.py
  function log_init (line 37) | def log_init(model_name, rank, stime=None):
  function create_hf_model (line 49) | def create_hf_model(model_class,
  function create_critic_model (line 93) | def create_critic_model(model_name_or_path,
  class DeepSpeedRLHFEngine (line 130) | class DeepSpeedRLHFEngine:
    method __init__ (line 132) | def __init__(self, actor_model_name_or_path, critic_model_name_or_path,
    method _init_actor (line 164) | def _init_actor(self, actor_model_name_or_path):
    method _init_ref (line 232) | def _init_ref(self, actor_model_name_or_path):
    method _init_ema (line 255) | def _init_ema(self, actor_model_name_or_path):
    method _init_critic (line 291) | def _init_critic(self, critic_model_name_or_path):
    method _init_reward (line 362) | def _init_reward(self, critic_model_name_or_path):

FILE: src/models/trainer.py
  function register_trainer (line 75) | def register_trainer(name):
  class BaseRLTrainer (line 98) | class BaseRLTrainer:
    method __init__ (line 99) | def __init__(
    method push_to_store (line 116) | def push_to_store(self, data):
    method add_eval_pipeline (line 119) | def add_eval_pipeline(self, eval_pipeline):
    method sample (line 124) | def sample(self, prompts: Iterable[str], length: int, n_samples: int) ...
    method learn (line 138) | def learn(
    method save (line 159) | def save(self, directory: Optional[str] = None):
    method load (line 164) | def load(self, directory=None):
  class AccelerateRLTrainer (line 170) | class AccelerateRLTrainer(BaseRLTrainer):
    method __init__ (line 175) | def __init__(self, config, **kwargs):  # noqa: C901
    method setup_model (line 252) | def setup_model(self):
    method setup_optimizer (line 278) | def setup_optimizer(self):
    method setup_scheduler (line 300) | def setup_scheduler(self):
    method decode (line 308) | def decode(
    method generate (line 374) | def generate(self, input_ids, attention_mask=None, **kwargs):
    method generate_eval (line 389) | def generate_eval(self, input_ids, attention_mask=None, **kwargs):
    method save_pretrained (line 402) | def save_pretrained(self, directory: Optional[str] = None, **kwargs):
    method save (line 420) | def save(self, directory: Optional[str] = None, **kwargs):
    method load (line 424) | def load(self, directory: Optional[str] = None, **kwargs):
    method add_eval_pipeline (line 428) | def add_eval_pipeline(self, eval_pipeline):
    method evaluate (line 432) | def evaluate(self):  # noqa: C901
    method learn (line 579) | def learn(self):  # noqa: C901
    method get_arch (line 703) | def get_arch(self, config: TRLConfig):
    method loss (line 708) | def loss(self, batch) -> Tuple[float, Dict]:
    method post_backward_callback (line 713) | def post_backward_callback(self):
    method post_epoch_callback (line 718) | def post_epoch_callback(self):
  class AcceleratePPOTrainer (line 724) | class AcceleratePPOTrainer(AccelerateRLTrainer):
    method __init__ (line 730) | def __init__(self, config: TRLConfig, **kwargs):
    method get_arch (line 810) | def get_arch(self, config: TRLConfig):
    method loss (line 830) | def loss(self, batch: PPORLBatch):
    method setup_rollout_logging (line 913) | def setup_rollout_logging(self, config):
    method post_epoch_callback (line 926) | def post_epoch_callback(self):
    method post_backward_callback (line 937) | def post_backward_callback(self):
    method prepare_learning (line 940) | def prepare_learning(self):
    method add_prompt_pipeline (line 949) | def add_prompt_pipeline(self, pipeline: BasePipeline):
    method make_experience (line 955) | def make_experience(self, num_rollouts: int = 1024, iter_count: int = ...
  function get_model_norm (line 1259) | def get_model_norm(model):
  function gather_log_probs (line 1273) | def gather_log_probs(logits, labels):
  class DeepSpeedPPOTrainer (line 1279) | class DeepSpeedPPOTrainer():
    method __init__ (line 1281) | def __init__(self, rlhf_engine, args):
    method generate_sequence (line 1302) | def generate_sequence(self, inputs):
    method generate_experience (line 1416) | def generate_experience(self, output_sequences, answer_start_indices):
    method compute_rewards (line 1455) | def compute_rewards(self, starts, log_probs, ref_log_probs, reward_sco...
    method train_rlhf (line 1482) | def train_rlhf(self, inputs):
    method actor_loss_fn (line 1563) | def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    method critic_loss_fn (line 1573) | def critic_loss_fn(self, values, old_values, returns, mask):
    method get_advantages_and_returns (line 1588) | def get_advantages_and_returns(self, values, rewards, starts):
    method _validate_training_mode (line 1632) | def _validate_training_mode(self):
    method _validate_evaluation_mode (line 1637) | def _validate_evaluation_mode(self):
    method train (line 1644) | def train(self):
    method eval (line 1649) | def eval(self):
    method dump_model_norms (line 1656) | def dump_model_norms(self, tag):
  class DeepSpeedPPOPTXTrainer (line 1674) | class DeepSpeedPPOPTXTrainer(DeepSpeedPPOTrainer):
    method __init__ (line 1676) | def __init__(self, *args, **kwargs):
    method train_unsupervised (line 1679) | def train_unsupervised(self, inputs, unsup_coef):
  class DPOTrainer (line 1691) | class DPOTrainer(Trainer):
    method __init__ (line 1738) | def __init__(
    method concatenated_inputs (line 1863) | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongT...
    method dpo_loss (line 1892) | def dpo_loss(
    method _get_batch_logps (line 1929) | def _get_batch_logps(
    method concatenated_forward (line 1962) | def concatenated_forward(
    method separate_forward (line 1986) | def separate_forward(
    method get_batch_metrics (line 2012) | def get_batch_metrics(
    method compute_loss (line 2077) | def compute_loss(
    method get_batch_samples (line 2093) | def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor])...
    method prediction_step (line 2130) | def prediction_step(
    method store_metrics (line 2164) | def store_metrics(self, metrics: Dict[str, float], train_eval: Literal...
    method log (line 2168) | def log(self, logs: Dict[str, float]) -> None:

FILE: src/pretrain.py
  function preprocess_logits_for_metrics (line 28) | def preprocess_logits_for_metrics(logits, labels):
  function get_parser (line 35) | def get_parser():
  function main (line 109) | def main():

FILE: src/pretrain_wo_trainer.py
  function preprocess_logits_for_metrics (line 30) | def preprocess_logits_for_metrics(logits, labels):
  function get_parser (line 37) | def get_parser():
  function pred_single_sample (line 110) | def pred_single_sample(prompt, prefix, model, tokenizer, args, device, e...
  function pred (line 180) | def pred(args, model, tokenizer, device, eos_token_id, step=-1):
  function main (line 196) | def main():

FILE: src/resources/models/baichuan-13B-base/configuration_baichuan.py
  class BaichuanConfig (line 5) | class BaichuanConfig(PretrainedConfig):
    method __init__ (line 9) | def __init__(

FILE: src/resources/models/baichuan-13B-base/modeling_baichuan.py
  function _get_interleave (line 20) | def _get_interleave(n):
  function _fill_with_neg_inf (line 34) | def _fill_with_neg_inf(t):
  function _gen_alibi_mask (line 39) | def _gen_alibi_mask(n_head, max_pos, alibi_mask=None):
  class RMSNorm (line 52) | class RMSNorm(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, hidden_size, epsilon=1e-6):
    method forward (line 58) | def forward(self, hidden_states):
  class MLP (line 69) | class MLP(torch.nn.Module):
    method __init__ (line 70) | def __init__(
    method forward (line 82) | def forward(self, x):
  class BaichuanAttention (line 86) | class BaichuanAttention(torch.nn.Module):
    method __init__ (line 88) | def __init__(self, config: BaichuanConfig):
    method _shape (line 103) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 106) | def forward(
  class BaichuanLayer (line 155) | class BaichuanLayer(torch.nn.Module):
    method __init__ (line 156) | def __init__(self, config: BaichuanConfig):
    method forward (line 168) | def forward(
  class BaichuanPreTrainedModel (line 205) | class BaichuanPreTrainedModel(PreTrainedModel):
    method _init_weights (line 212) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 223) | def _set_gradient_checkpointing(self, module, value=False):
  class BaichuanModel (line 228) | class BaichuanModel(BaichuanPreTrainedModel):
    method __init__ (line 229) | def __init__(self, config: BaichuanConfig):
    method get_input_embeddings (line 243) | def get_input_embeddings(self):
    method set_input_embeddings (line 246) | def set_input_embeddings(self, value):
    method get_alibi_mask (line 249) | def get_alibi_mask(self, tensor, seq_length_with_past, attention_mask):
    method forward (line 267) | def forward(
  class BaichuanForCausalLM (line 368) | class BaichuanForCausalLM(BaichuanPreTrainedModel):
    method __init__ (line 369) | def __init__(self, config):
    method forward (line 377) | def forward(
    method prepare_inputs_for_generation (line 430) | def prepare_inputs_for_generation(
    method _reorder_cache (line 451) | def _reorder_cache(past_key_values, beam_idx):
    method quantize (line 457) | def quantize(self, bits: int):
    method _build_chat_input (line 493) | def _build_chat_input(self, tokenizer, messages: List[dict], max_new_t...
    method chat (line 524) | def chat(self, tokenizer, messages: List[dict], stream=False,

FILE: src/resources/models/baichuan-13B-base/tokenization_baichuan.py
  class BaichuanTokenizer (line 23) | class BaichuanTokenizer(PreTrainedTokenizer):
    method __init__ (line 37) | def __init__(
    method __getstate__ (line 72) | def __getstate__(self):
    method __setstate__ (line 77) | def __setstate__(self, d):
    method vocab_size (line 83) | def vocab_size(self):
    method get_vocab (line 87) | def get_vocab(self):
    method _tokenize (line 93) | def _tokenize(self, text):
    method _convert_token_to_id (line 97) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 101) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 106) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 125) | def save_vocabulary(self, save_directory, filename_prefix: Optional[st...
    method build_inputs_with_special_tokens (line 152) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method get_special_tokens_mask (line 163) | def get_special_tokens_mask(
    method create_token_type_ids_from_sequences (line 200) | def create_token_type_ids_from_sequences(

FILE: src/resources/models/baichuan-13B-chat/configuration_baichuan.py
  class BaichuanConfig (line 5) | class BaichuanConfig(PretrainedConfig):
    method __init__ (line 9) | def __init__(

FILE: src/resources/models/baichuan-13B-chat/modeling_baichuan.py
  function _get_interleave (line 20) | def _get_interleave(n):
  function _fill_with_neg_inf (line 34) | def _fill_with_neg_inf(t):
  function _gen_alibi_mask (line 39) | def _gen_alibi_mask(n_head, max_pos, alibi_mask=None):
  class RMSNorm (line 52) | class RMSNorm(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, hidden_size, epsilon=1e-6):
    method forward (line 58) | def forward(self, hidden_states):
  class MLP (line 69) | class MLP(torch.nn.Module):
    method __init__ (line 70) | def __init__(
    method forward (line 82) | def forward(self, x):
  class BaichuanAttention (line 86) | class BaichuanAttention(torch.nn.Module):
    method __init__ (line 88) | def __init__(self, config: BaichuanConfig):
    method _shape (line 103) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 106) | def forward(
  class BaichuanLayer (line 155) | class BaichuanLayer(torch.nn.Module):
    method __init__ (line 156) | def __init__(self, config: BaichuanConfig):
    method forward (line 168) | def forward(
  class BaichuanPreTrainedModel (line 205) | class BaichuanPreTrainedModel(PreTrainedModel):
    method _init_weights (line 212) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 223) | def _set_gradient_checkpointing(self, module, value=False):
  class BaichuanModel (line 228) | class BaichuanModel(BaichuanPreTrainedModel):
    method __init__ (line 229) | def __init__(self, config: BaichuanConfig):
    method get_input_embeddings (line 243) | def get_input_embeddings(self):
    method set_input_embeddings (line 246) | def set_input_embeddings(self, value):
    method get_alibi_mask (line 249) | def get_alibi_mask(self, tensor, seq_length_with_past, attention_mask):
    method forward (line 267) | def forward(
  class BaichuanForCausalLM (line 368) | class BaichuanForCausalLM(BaichuanPreTrainedModel):
    method __init__ (line 369) | def __init__(self, config):
    method forward (line 377) | def forward(
    method prepare_inputs_for_generation (line 430) | def prepare_inputs_for_generation(
    method _reorder_cache (line 451) | def _reorder_cache(past_key_values, beam_idx):
    method quantize (line 458) | def quantize(self, bits: int):
    method _build_chat_input (line 494) | def _build_chat_input(self, tokenizer, messages: List[dict], max_new_t...
    method chat (line 525) | def chat(self, tokenizer, messages: List[dict], stream=False,

FILE: src/resources/models/baichuan-13B-chat/tokenization_baichuan.py
  class BaichuanTokenizer (line 23) | class BaichuanTokenizer(PreTrainedTokenizer):
    method __init__ (line 37) | def __init__(
    method __getstate__ (line 72) | def __getstate__(self):
    method __setstate__ (line 77) | def __setstate__(self, d):
    method vocab_size (line 83) | def vocab_size(self):
    method get_vocab (line 87) | def get_vocab(self):
    method _tokenize (line 93) | def _tokenize(self, text):
    method _convert_token_to_id (line 97) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 101) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 106) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 125) | def save_vocabulary(self, save_directory, filename_prefix: Optional[st...
    method build_inputs_with_special_tokens (line 152) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method get_special_tokens_mask (line 163) | def get_special_tokens_mask(
    method create_token_type_ids_from_sequences (line 200) | def create_token_type_ids_from_sequences(

FILE: src/resources/models/baichuan-7B/configuration_baichuan.py
  class BaiChuanConfig (line 28) | class BaiChuanConfig(PretrainedConfig):
    method __init__ (line 32) | def __init__(

FILE: src/resources/models/baichuan-7B/modeling_baichuan.py
  function _make_causal_mask (line 39) | def _make_causal_mask(
  function _expand_mask (line 57) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class RMSNorm (line 71) | class RMSNorm(nn.Module):
    method __init__ (line 72) | def __init__(self, hidden_size, eps=1e-6):
    method forward (line 80) | def forward(self, hidden_states):
  class RotaryEmbedding (line 91) | class RotaryEmbedding(torch.nn.Module):
    method __init__ (line 92) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi...
    method forward (line 106) | def forward(self, x, seq_len=None):
  function rotate_half (line 123) | def rotate_half(x):
  function apply_rotary_pos_emb (line 130) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  class MLP (line 141) | class MLP(nn.Module):
    method __init__ (line 142) | def __init__(
    method forward (line 154) | def forward(self, x):
  class Attention (line 158) | class Attention(nn.Module):
    method __init__ (line 161) | def __init__(self, config: BaiChuanConfig):
    method _shape (line 181) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 184) | def forward(
  class DecoderLayer (line 259) | class DecoderLayer(nn.Module):
    method __init__ (line 260) | def __init__(self, config: BaiChuanConfig):
    method forward (line 272) | def forward(
  class PreTrainedModel (line 327) | class PreTrainedModel(PreTrainedModel):
    method _init_weights (line 334) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 345) | def _set_gradient_checkpointing(self, module, value=False):
  class Model (line 350) | class Model(PreTrainedModel):
    method __init__ (line 358) | def __init__(self, config: BaiChuanConfig):
    method get_input_embeddings (line 371) | def get_input_embeddings(self):
    method set_input_embeddings (line 374) | def set_input_embeddings(self, value):
    method _prepare_decoder_attention_mask (line 378) | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,...
    method forward (line 401) | def forward(
  class BaiChuanForCausalLM (line 529) | class BaiChuanForCausalLM(PreTrainedModel):
    method __init__ (line 530) | def __init__(self, config):
    method get_input_embeddings (line 539) | def get_input_embeddings(self):
    method set_input_embeddings (line 542) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 545) | def get_output_embeddings(self):
    method set_output_embeddings (line 548) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 551) | def set_decoder(self, decoder):
    method get_decoder (line 554) | def get_decoder(self):
    method forward (line 557) | def forward(
    method prepare_inputs_for_generation (line 643) | def prepare_inputs_for_generation(
    method _reorder_cache (line 674) | def _reorder_cache(past_key_values, beam_idx):

FILE: src/resources/models/baichuan-7B/tokenization_baichuan.py
  class BaiChuanTokenizer (line 42) | class BaiChuanTokenizer(PreTrainedTokenizer):
    method __init__ (line 56) | def __init__(
    method __getstate__ (line 91) | def __getstate__(self):
    method __setstate__ (line 96) | def __setstate__(self, d):
    method vocab_size (line 102) | def vocab_size(self):
    method get_vocab (line 106) | def get_vocab(self):
    method _tokenize (line 112) | def _tokenize(self, text):
    method _convert_token_to_id (line 116) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 120) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 125) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 144) | def save_vocabulary(self, save_directory, filename_prefix: Optional[st...
    method build_inputs_with_special_tokens (line 171) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method get_special_tokens_mask (line 182) | def get_special_tokens_mask(
    method create_token_type_ids_from_sequences (line 219) | def create_token_type_ids_from_sequences(

FILE: src/resources/models/chatglm-6B/configuration_chatglm.py
  class ChatGLMConfig (line 9) | class ChatGLMConfig(PretrainedConfig):
    method __init__ (line 59) | def __init__(

FILE: src/resources/models/chatglm-6B/modeling_chatglm.py
  class InvalidScoreLogitsProcessor (line 54) | class InvalidScoreLogitsProcessor(LogitsProcessor):
    method __call__ (line 55) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function load_tf_weights_in_chatglm_6b (line 62) | def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
  class PrefixEncoder (line 136) | class PrefixEncoder(torch.nn.Module):
    method __init__ (line 143) | def __init__(self, config):
    method forward (line 157) | def forward(self, prefix: torch.Tensor):
  function gelu_impl (line 167) | def gelu_impl(x):
  function gelu (line 173) | def gelu(x):
  class RotaryEmbedding (line 177) | class RotaryEmbedding(torch.nn.Module):
    method __init__ (line 178) | def __init__(self, dim, base=10000, precision=torch.half, learnable=Fa...
    method _load_from_state_dict (line 193) | def _load_from_state_dict(self, state_dict, prefix, local_metadata, st...
    method forward (line 197) | def forward(self, x, seq_dim=1, seq_len=None):
    method _apply (line 220) | def _apply(self, fn):
  function rotate_half (line 228) | def rotate_half(x):
  function apply_rotary_pos_emb_index (line 234) | def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
  function attention_fn (line 242) | def attention_fn(
  class SelfAttention (line 351) | class SelfAttention(torch.nn.Module):
    method __init__ (line 352) | def __init__(self, hidden_size, num_attention_heads,
    method attention_mask_func (line 399) | def attention_mask_func(attention_scores, attention_mask):
    method split_tensor_along_last_dim (line 403) | def split_tensor_along_last_dim(self, tensor, num_partitions,
    method forward (line 423) | def forward(
  class GEGLU (line 490) | class GEGLU(torch.nn.Module):
    method __init__ (line 491) | def __init__(self):
    method forward (line 495) | def forward(self, x):
  class GLU (line 501) | class GLU(torch.nn.Module):
    method __init__ (line 502) | def __init__(self, hidden_size, inner_hidden_size=None,
    method forward (line 529) | def forward(self, hidden_states):
  class GLMBlock (line 544) | class GLMBlock(torch.nn.Module):
    method __init__ (line 545) | def __init__(
    method forward (line 594) | def forward(
  class ChatGLMPreTrainedModel (line 648) | class ChatGLMPreTrainedModel(PreTrainedModel):
    method __init__ (line 660) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 663) | def _init_weights(self, module: nn.Module):
    method _set_gradient_checkpointing (line 667) | def _set_gradient_checkpointing(self, module, value=False):
  class ChatGLMModel (line 737) | class ChatGLMModel(ChatGLMPreTrainedModel):
    method __init__ (line 753) | def __init__(self, config: ChatGLMConfig):
    method get_input_embeddings (line 809) | def get_input_embeddings(self):
    method set_input_embeddings (line 812) | def set_input_embeddings(self, new_embeddings: torch.Tensor):
    method get_prompt (line 815) | def get_prompt(self, batch_size, device, dtype=torch.half):
    method get_masks (line 831) | def get_masks(self, input_ids, device):
    method get_position_ids (line 843) | def get_position_ids(self, input_ids, mask_positions, device, gmask=Fa...
    method forward (line 871) | def forward(
  class ChatGLMForConditionalGeneration (line 1008) | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    method __init__ (line 1009) | def __init__(self, config: ChatGLMConfig):
    method get_output_embeddings (line 1036) | def get_output_embeddings(self):
    method set_output_embeddings (line 1039) | def set_output_embeddings(self, new_embeddings):
    method get_masks_and_position_ids (line 1042) | def get_masks_and_position_ids(self, input_ids, mask_positions, device...
    method prepare_inputs_for_generation (line 1073) | def prepare_inputs_for_generation(
    method forward (line 1122) | def forward(
    method _reorder_cache (line 1181) | def _reorder_cache(
    method process_response (line 1199) | def process_response(self, response):
    method chat (line 1215) | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] =...
    method stream_chat (line 1241) | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, ...
    method stream_generate (line 1267) | def stream_generate(
    method quantize (line 1368) | def quantize(self, bits: int, empty_init=False, **kwargs):

FILE: src/resources/models/chatglm-6B/quantization.py
  class Kernel (line 18) | class Kernel:
    method __init__ (line 19) | def __init__(self, code: bytes, function_names: List[str]):
  class W8A16Linear (line 44) | class W8A16Linear(torch.autograd.Function):
    method forward (line 46) | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: to...
    method backward (line 58) | def backward(ctx, grad_output: torch.Tensor):
  function compress_int4_weight (line 67) | def compress_int4_weight(weight: torch.Tensor):  # (n, m)
  function extract_weight_to_half (line 88) | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tenso...
  class QuantizedLinear (line 120) | class QuantizedLinear(Linear):
    method __init__ (line 121) | def __init__(self, weight_bit_width: int, weight_tensor=None, bias_ten...
    method forward (line 146) | def forward(self, input):
  function quantize (line 153) | def quantize(model, weight_bit_width, empty_init=False, **kwargs):

FILE: src/resources/models/chatglm-6B/tokenization_chatglm.py
  class SPTokenizer (line 23) | class SPTokenizer:
    method __init__ (line 24) | def __init__(
    method _configure_tokenizer (line 39) | def _configure_tokenizer(
    method _build_text_tokenizer (line 66) | def _build_text_tokenizer(self, encode_special_tokens=False):
    method _get_text_tokenizer (line 73) | def _get_text_tokenizer(self, encode_special_tokens=False):
    method get_blank_token (line 80) | def get_blank_token(length: int):
    method get_tab_token (line 85) | def get_tab_token():
    method num_image_tokens (line 89) | def num_image_tokens(self):
    method num_text_tokens (line 93) | def num_text_tokens(self):
    method num_tokens (line 97) | def num_tokens(self):
    method _encode_whitespaces (line 101) | def _encode_whitespaces(text: str, max_len: int = 80):
    method _preprocess (line 107) | def _preprocess(self, text: str, linebreak=True, whitespaces=True):
    method encode (line 114) | def encode(
    method decode (line 131) | def decode(self, text_ids: List[int], special_tokens=False) -> str:
    method tokenize (line 145) | def tokenize(
    method __getitem__ (line 161) | def __getitem__(self, x: Union[int, str]):
  class ChatGLMTokenizer (line 176) | class ChatGLMTokenizer(PreTrainedTokenizer):
    method __init__ (line 189) | def __init__(
    method eop_token_id (line 224) | def eop_token_id(self) -> Optional[int]:
    method gmask_token_id (line 234) | def gmask_token_id(self) -> Optional[int]:
    method vocab_size (line 244) | def vocab_size(self):
    method get_vocab (line 248) | def get_vocab(self):
    method preprocess_text (line 254) | def preprocess_text(self, inputs):
    method _tokenize (line 265) | def _tokenize(self, text, **kwargs):
    method decode (line 273) | def decode(
    method _convert_token_to_id (line 293) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 297) | def _convert_id_to_token(self, index):
    method save_vocabulary (line 301) | def save_vocabulary(self, save_directory, filename_prefix=None):
    method build_inputs_with_special_tokens (line 353) | def build_inputs_with_special_tokens(

FILE: src/resources/models/chatglm2-6B/configuration_chatglm.py
  class ChatGLMConfig (line 4) | class ChatGLMConfig(PretrainedConfig):
    method __init__ (line 5) | def __init__(

FILE: src/resources/models/chatglm2-6B/modeling_chatglm.py
  function default_init (line 47) | def default_init(cls, *args, **kwargs):
  class InvalidScoreLogitsProcessor (line 51) | class InvalidScoreLogitsProcessor(LogitsProcessor):
    method __call__ (line 52) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function split_tensor_along_last_dim (line 59) | def split_tensor_along_last_dim(
  class RotaryEmbedding (line 87) | class RotaryEmbedding(nn.Module):
    method __init__ (line 88) | def __init__(self, dim, original_impl=False, device=None, dtype=None):
    method forward_impl (line 95) | def forward_impl(
    method forward (line 120) | def forward(self, max_seq_len, offset=0):
  function apply_rotary_pos_emb (line 127) | def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> t...
  class RMSNorm (line 147) | class RMSNorm(torch.nn.Module):
    method __init__ (line 148) | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None...
    method forward (line 153) | def forward(self, hidden_states: torch.Tensor):
  class CoreAttention (line 161) | class CoreAttention(torch.nn.Module):
    method __init__ (line 162) | def __init__(self, config: ChatGLMConfig, layer_number):
    method forward (line 187) | def forward(self, query_layer, key_layer, value_layer, attention_mask):
  class SelfAttention (line 282) | class SelfAttention(torch.nn.Module):
    method __init__ (line 289) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method _allocate_memory (line 318) | def _allocate_memory(self, inference_max_sequence_len, batch_size, dev...
    method forward (line 332) | def forward(
  function _config_to_kwargs (line 421) | def _config_to_kwargs(args):
  class MLP (line 428) | class MLP(torch.nn.Module):
    method __init__ (line 436) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 465) | def forward(self, hidden_states):
  class GLMBlock (line 474) | class GLMBlock(torch.nn.Module):
    method __init__ (line 481) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method forward (line 505) | def forward(
  class GLMTransformer (line 548) | class GLMTransformer(torch.nn.Module):
    method __init__ (line 551) | def __init__(self, config: ChatGLMConfig, device=None):
    method _get_layer (line 574) | def _get_layer(self, layer_number):
    method forward (line 577) | def forward(
  class ChatGLMPreTrainedModel (line 624) | class ChatGLMPreTrainedModel(PreTrainedModel):
    method _init_weights (line 636) | def _init_weights(self, module: nn.Module):
    method get_masks (line 640) | def get_masks(self, input_ids, past_key_values, padding_mask=None):
    method get_position_ids (line 658) | def get_position_ids(self, input_ids, device):
    method _set_gradient_checkpointing (line 663) | def _set_gradient_checkpointing(self, module, value=False):
  class Embedding (line 671) | class Embedding(torch.nn.Module):
    method __init__ (line 674) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 687) | def forward(self, input_ids):
  class ChatGLMModel (line 699) | class ChatGLMModel(ChatGLMPreTrainedModel):
    method __init__ (line 700) | def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
    method get_input_embeddings (line 724) | def get_input_embeddings(self):
    method forward (line 727) | def forward(
    method quantize (line 785) | def quantize(self, weight_bit_width: int):
  class ChatGLMForConditionalGeneration (line 791) | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    method __init__ (line 792) | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    method _update_model_kwargs_for_generation (line 803) | def _update_model_kwargs_for_generation(
    method forward (line 834) | def forward(
    method prepare_inputs_for_generation (line 896) | def prepare_inputs_for_generation(
    method _reorder_cache (line 920) | def _reorder_cache(
    method process_response (line 938) | def process_response(self, response):
    method build_inputs (line 943) | def build_inputs(self, tokenizer, query: str, history: List[Tuple[str,...
    method build_stream_inputs (line 952) | def build_stream_inputs(self, tokenizer, query: str, history: List[Tup...
    method chat (line 966) | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] =...
    method stream_chat (line 984) | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, ...
    method stream_generate (line 1018) | def stream_generate(
    method quantize (line 1122) | def quantize(self, bits: int, empty_init=False, device=None, **kwargs):

FILE: src/resources/models/chatglm2-6B/quantization.py
  class Kernel (line 18) | class Kernel:
    method __init__ (line 19) | def __init__(self, code: bytes, function_names: List[str]):
  class W8A16Linear (line 44) | class W8A16Linear(torch.autograd.Function):
    method forward (line 46) | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: to...
    method backward (line 58) | def backward(ctx, grad_output: torch.Tensor):
  function compress_int4_weight (line 67) | def compress_int4_weight(weight: torch.Tensor):  # (n, m)
  function extract_weight_to_half (line 88) | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tenso...
  class QuantizedLinear (line 124) | class QuantizedLinear(torch.nn.Module):
    method __init__ (line 125) | def __init__(self, weight_bit_width: int, weight, bias=None, device="c...
    method forward (line 145) | def forward(self, input):
  function quantize (line 152) | def quantize(model, weight_bit_width, empty_init=False, device=None):

FILE: src/resources/models/chatglm2-6B/tokenization_chatglm.py
  class SPTokenizer (line 10) | class SPTokenizer:
    method __init__ (line 11) | def __init__(self, model_path: str):
    method tokenize (line 31) | def tokenize(self, s: str):
    method encode (line 34) | def encode(self, s: str, bos: bool = False, eos: bool = False) -> List...
    method decode (line 43) | def decode(self, t: List[int]) -> str:
    method decode_tokens (line 46) | def decode_tokens(self, tokens: List[str]) -> str:
    method convert_token_to_id (line 50) | def convert_token_to_id(self, token):
    method convert_id_to_token (line 56) | def convert_id_to_token(self, index):
  class ChatGLMTokenizer (line 63) | class ChatGLMTokenizer(PreTrainedTokenizer):
    method __init__ (line 68) | def __init__(self, vocab_file, padding_side="right", **kwargs):
    method get_command (line 80) | def get_command(self, token):
    method pad_token (line 87) | def pad_token(self) -> str:
    method pad_token_id (line 91) | def pad_token_id(self):
    method vocab_size (line 95) | def vocab_size(self):
    method get_vocab (line 98) | def get_vocab(self):
    method _tokenize (line 104) | def _tokenize(self, text, **kwargs):
    method _convert_token_to_id (line 107) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 111) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 115) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method save_vocabulary (line 133) | def save_vocabulary(self, save_directory, filename_prefix=None):
    method get_prefix_tokens (line 161) | def get_prefix_tokens(self):
    method build_inputs_with_special_tokens (line 165) | def build_inputs_with_special_tokens(
    method _pad (line 191) | def _pad(

FILE: src/resources/models/chatglm3-6B/configuration_chatglm.py
  class ChatGLMConfig (line 4) | class ChatGLMConfig(PretrainedConfig):
    method __init__ (line 6) | def __init__(

FILE: src/resources/models/chatglm3-6B/modeling_chatglm.py
  function default_init (line 49) | def default_init(cls, *args, **kwargs):
  class InvalidScoreLogitsProcessor (line 53) | class InvalidScoreLogitsProcessor(LogitsProcessor):
    method __call__ (line 54) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  class PrefixEncoder (line 61) | class PrefixEncoder(torch.nn.Module):
    method __init__ (line 68) | def __init__(self, config: ChatGLMConfig):
    method forward (line 84) | def forward(self, prefix: torch.Tensor):
  function split_tensor_along_last_dim (line 93) | def split_tensor_along_last_dim(
  class RotaryEmbedding (line 121) | class RotaryEmbedding(nn.Module):
    method __init__ (line 122) | def __init__(self, dim, original_impl=False, device=None, dtype=None):
    method forward_impl (line 129) | def forward_impl(
    method forward (line 154) | def forward(self, max_seq_len, offset=0):
  function apply_rotary_pos_emb (line 161) | def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> t...
  class RMSNorm (line 181) | class RMSNorm(torch.nn.Module):
    method __init__ (line 182) | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None...
    method forward (line 187) | def forward(self, hidden_states: torch.Tensor):
  class CoreAttention (line 195) | class CoreAttention(torch.nn.Module):
    method __init__ (line 196) | def __init__(self, config: ChatGLMConfig, layer_number):
    method forward (line 221) | def forward(self, query_layer, key_layer, value_layer, attention_mask):
  class SelfAttention (line 313) | class SelfAttention(torch.nn.Module):
    method __init__ (line 320) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method _allocate_memory (line 349) | def _allocate_memory(self, inference_max_sequence_len, batch_size, dev...
    method forward (line 363) | def forward(
  function _config_to_kwargs (line 452) | def _config_to_kwargs(args):
  class MLP (line 459) | class MLP(torch.nn.Module):
    method __init__ (line 467) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 496) | def forward(self, hidden_states):
  class GLMBlock (line 505) | class GLMBlock(torch.nn.Module):
    method __init__ (line 512) | def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    method forward (line 536) | def forward(
  class GLMTransformer (line 579) | class GLMTransformer(torch.nn.Module):
    method __init__ (line 582) | def __init__(self, config: ChatGLMConfig, device=None):
    method _get_layer (line 605) | def _get_layer(self, layer_number):
    method forward (line 608) | def forward(
  class ChatGLMPreTrainedModel (line 661) | class ChatGLMPreTrainedModel(PreTrainedModel):
    method _init_weights (line 673) | def _init_weights(self, module: nn.Module):
    method get_masks (line 677) | def get_masks(self, input_ids, past_key_values, padding_mask=None):
    method get_position_ids (line 695) | def get_position_ids(self, input_ids, device):
    method _set_gradient_checkpointing (line 700) | def _set_gradient_checkpointing(self, module, value=False):
  class Embedding (line 705) | class Embedding(torch.nn.Module):
    method __init__ (line 708) | def __init__(self, config: ChatGLMConfig, device=None):
    method forward (line 721) | def forward(self, input_ids):
  class ChatGLMModel (line 733) | class ChatGLMModel(ChatGLMPreTrainedModel):
    method __init__ (line 734) | def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
    method get_input_embeddings (line 768) | def get_input_embeddings(self):
    method get_prompt (line 771) | def get_prompt(self, batch_size, device, dtype=torch.half):
    method forward (line 786) | def forward(
    method quantize (line 845) | def quantize(self, weight_bit_width: int):
  class ChatGLMForConditionalGeneration (line 851) | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    method __init__ (line 852) | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    method _update_model_kwargs_for_generation (line 863) | def _update_model_kwargs_for_generation(
    method prepare_inputs_for_generation (line 894) | def prepare_inputs_for_generation(
    method forward (line 920) | def forward(
    method _reorder_cache (line 981) | def _reorder_cache(
    method process_response (line 999) | def process_response(self, output, history):
    method chat (line 1021) | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] =...
    method stream_chat (line 1043) | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, ...
    method stream_generate (line 1084) | def stream_generate(
    method quantize (line 1191) | def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
  class ChatGLMForSequenceClassification (line 1210) | class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    method __init__ (line 1211) | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    method forward (line 1227) | def forward(

FILE: src/resources/models/chatglm3-6B/quantization.py
  class Kernel (line 18) | class Kernel:
    method __init__ (line 19) | def __init__(self, code: bytes, function_names: List[str]):
  class W8A16Linear (line 44) | class W8A16Linear(torch.autograd.Function):
    method forward (line 46) | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: to...
    method backward (line 58) | def backward(ctx, grad_output: torch.Tensor):
  function compress_int4_weight (line 67) | def compress_int4_weight(weight: torch.Tensor):  # (n, m)
  function extract_weight_to_half (line 88) | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tenso...
  class QuantizedLinear (line 124) | class QuantizedLinear(torch.nn.Module):
    method __init__ (line 125) | def __init__(self, weight_bit_width: int, weight, bias=None, device="c...
    method forward (line 145) | def forward(self, input):
  function quantize (line 152) | def quantize(model, weight_bit_width, empty_init=False, device=None):

FILE: src/resources/models/chatglm3-6B/tokenization_chatglm.py
  class SPTokenizer (line 11) | class SPTokenizer:
    method __init__ (line 12) | def __init__(self, model_path: str):
    method tokenize (line 34) | def tokenize(self, s: str):
    method encode (line 37) | def encode(self, s: str, bos: bool = False, eos: bool = False) -> List...
    method decode (line 46) | def decode(self, t: List[int]) -> str:
    method decode_tokens (line 60) | def decode_tokens(self, tokens: List[str]) -> str:
    method convert_token_to_id (line 64) | def convert_token_to_id(self, token):
    method convert_id_to_token (line 70) | def convert_id_to_token(self, index):
  class ChatGLMTokenizer (line 79) | class ChatGLMTokenizer(PreTrainedTokenizer):
    method __init__ (line 84) | def __init__(self, vocab_file, padding_side="left", clean_up_tokenizat...
    method get_command (line 96) | def get_command(self, token):
    method unk_token (line 103) | def unk_token(self) -> str:
    method pad_token (line 107) | def pad_token(self) -> str:
    method pad_token_id (line 111) | def pad_token_id(self):
    method eos_token (line 115) | def eos_token(self) -> str:
    method eos_token_id (line 119) | def eos_token_id(self):
    method vocab_size (line 123) | def vocab_size(self):
    method get_vocab (line 126) | def get_vocab(self):
    method _tokenize (line 132) | def _tokenize(self, text, **kwargs):
    method _convert_token_to_id (line 135) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 139) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 143) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method save_vocabulary (line 146) | def save_vocabulary(self, save_directory, filename_prefix=None):
    method get_prefix_tokens (line 174) | def get_prefix_tokens(self):
    method build_single_message (line 178) | def build_single_message(self, role, metadata, message):
    method build_chat_input (line 185) | def build_chat_input(self, query, history=None, role="user"):
    method build_inputs_with_special_tokens (line 198) | def build_inputs_with_special_tokens(
    method _pad (line 223) | def _pad(

FILE: src/resources/models/glm-10B-chinese/configuration_glm.py
  class GLMConfig (line 28) | class GLMConfig(PretrainedConfig):
    method __init__ (line 93) | def __init__(

FILE: src/resources/models/glm-10B-chinese/modeling_glm.py
  function unscaled_init_method (line 52) | def unscaled_init_method(sigma):
  function scaled_init_method (line 61) | def scaled_init_method(mean, std, num_layers):
  function ensure_divisibility (line 71) | def ensure_divisibility(numerator, denominator):
  function divide (line 77) | def divide(numerator, denominator):
  function split_tensor_along_last_dim (line 84) | def split_tensor_along_last_dim(tensor, num_partitions,
  class MLP (line 105) | class MLP(torch.nn.Module):
    method __init__ (line 124) | def __init__(self, hidden_size, output_dropout_prob, init_method,
    method forward (line 140) | def forward(self, hidden_states):
  class VocabEmbedding (line 151) | class VocabEmbedding(torch.nn.Module):
    method __init__ (line 162) | def __init__(self, config):
    method forward (line 185) | def forward(self, input_):
  class PositionalEmbedding (line 194) | class PositionalEmbedding(torch.nn.Module):
    method __init__ (line 196) | def __init__(self, hidden_size):
    method forward (line 204) | def forward(self, pos_seq, bsz=None):
  class SelfAttention (line 214) | class SelfAttention(torch.nn.Module):
    method __init__ (line 241) | def __init__(self, hidden_size, num_attention_heads,
    method _transpose_for_scores (line 269) | def _transpose_for_scores(self, tensor):
    method forward (line 279) | def forward(self, hidden_states, ltor_mask, mem=None):
  class GLMBlock (line 346) | class GLMBlock(torch.nn.Module):
    method __init__ (line 375) | def __init__(self,
    method forward (line 413) | def forward(self, hidden_states, ltor_mask, mem=None):
  class GLMStack (line 434) | class GLMStack(torch.nn.Module):
    method __init__ (line 469) | def __init__(self,
    method forward (line 528) | def forward(self, hidden_states, position_ids, attention_mask, memory_...
    method update_mems (line 604) | def update_mems(self, hiddens, mems):
  class GLMPreTrainedModel (line 619) | class GLMPreTrainedModel(PreTrainedModel):
    method _init_weights (line 630) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 646) | def _set_gradient_checkpointing(self, module, value=False):
  class GLMModel (line 716) | class GLMModel(GLMPreTrainedModel):
    method __init__ (line 732) | def __init__(self, config):
    method forward (line 762) | def forward(
  class GLMForMultipleChoice (line 801) | class GLMForMultipleChoice(GLMPreTrainedModel):
    method __init__ (line 802) | def __init__(self, config):
    method forward (line 807) | def forward(
  class GLMForConditionalGeneration (line 843) | class GLMForConditionalGeneration(GLMPreTrainedModel):
    method __init__ (line 844) | def __init__(self, config):
    method _reorder_cache (line 849) | def _reorder_cache(self, past, beam_idx):
    method prepare_inputs_for_generation (line 861) | def prepare_inputs_for_generation(self, input_ids, past=None, position...
    method forward (line 894) | def forward(
  class GLMForSequenceClassification (line 921) | class GLMForSequenceClassification(GLMPreTrainedModel):
    method __init__ (line 922) | def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
    method forward (line 946) | def forward(self,

FILE: src/resources/models/glm-10B-chinese/tokenization_glm.py
  class GLMBatchEncoding (line 17) | class GLMBatchEncoding(BatchEncoding):
    method to (line 19) | def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
  class GLMTokenizerMixin (line 38) | class GLMTokenizerMixin:
    method sop_token (line 40) | def sop_token(self) -> Optional[str]:
    method sop_token_id (line 44) | def sop_token_id(self) -> Optional[int]:
    method eop_token (line 51) | def eop_token(self) -> Optional[str]:
    method eop_token_id (line 55) | def eop_token_id(self) -> Optional[int]:
    method gmask_token_id (line 62) | def gmask_token_id(self) -> int:
    method smask_token_id (line 66) | def smask_token_id(self) -> int:
    method mask_token_ids (line 70) | def mask_token_ids(self):
    method _build_input_for_multiple_choice (line 73) | def _build_input_for_multiple_choice(self, context, choices):
    method _pad_batch (line 110) | def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_len...
    method _collate (line 122) | def _collate(self, samples):
    method build_inputs_for_multiple_choice (line 146) | def build_inputs_for_multiple_choice(self, model_input: BatchEncoding,...
    method build_inputs_for_generation (line 153) | def build_inputs_for_generation(self, model_input: BatchEncoding, max_...
  class GLMRobertaTokenizer (line 209) | class GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):
    method gmask_token_id (line 214) | def gmask_token_id(self) -> int:
    method smask_token_id (line 218) | def smask_token_id(self) -> int:
    method mask_token_ids (line 222) | def mask_token_ids(self):
  class GLMChineseTokenizer (line 226) | class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    method __init__ (line 230) | def __init__(self, vocab_file, **kwargs):
    method vocab_size (line 237) | def vocab_size(self):
    method get_vocab (line 240) | def get_vocab(self):
    method _tokenize (line 245) | def _tokenize(self, text, **kwargs):
    method _convert_token_to_id (line 248) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 252) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 256) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 259) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...
    method build_inputs_with_special_tokens (line 276) | def build_inputs_with_special_tokens(
  class GLMGPT2Tokenizer (line 308) | class GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):
    method build_inputs_with_special_tokens (line 312) | def build_inputs_with_special_tokens(
  class GLMBertTokenizer (line 334) | class GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):
    method gmask_token_id (line 339) | def gmask_token_id(self) -> int:
    method smask_token_id (line 343) | def smask_token_id(self) -> int:
    method mask_token_ids (line 347) | def mask_token_ids(self):
  class GLMTokenizer (line 351) | class GLMTokenizer:
    method from_pretrained (line 353) | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwa...

FILE: src/resources/models/glm-350M-chinese/configuration_glm.py
  class GLMConfig (line 28) | class GLMConfig(PretrainedConfig):
    method __init__ (line 86) | def __init__(

FILE: src/resources/models/glm-350M-chinese/modeling_glm.py
  function unscaled_init_method (line 51) | def unscaled_init_method(sigma):
  function scaled_init_method (line 60) | def scaled_init_method(mean, std, num_layers):
  function ensure_divisibility (line 70) | def ensure_divisibility(numerator, denominator):
  function divide (line 76) | def divide(numerator, denominator):
  function split_tensor_along_last_dim (line 83) | def split_tensor_along_last_dim(tensor, num_partitions,
  class MLP (line 104) | class MLP(torch.nn.Module):
    method __init__ (line 121) | def __init__(self, hidden_size, output_dropout_prob, init_method,
    method forward (line 137) | def forward(self, hidden_states):
  class VocabEmbedding (line 148) | class VocabEmbedding(torch.nn.Module):
    method __init__ (line 158) | def __init__(self, config):
    method forward (line 181) | def forward(self, input_):
  class PositionalEmbedding (line 190) | class PositionalEmbedding(torch.nn.Module):
    method __init__ (line 192) | def __init__(self, hidden_size):
    method forward (line 200) | def forward(self, pos_seq, bsz=None):
  class SelfAttention (line 210) | class SelfAttention(torch.nn.Module):
    method __init__ (line 236) | def __init__(self, hidden_size, num_attention_heads,
    method _transpose_for_scores (line 264) | def _transpose_for_scores(self, tensor):
    method forward (line 274) | def forward(self, hidden_states, ltor_mask, mem=None):
  class GLMBlock (line 341) | class GLMBlock(torch.nn.Module):
    method __init__ (line 368) | def __init__(self,
    method forward (line 406) | def forward(self, hidden_states, ltor_mask, mem=None):
  class GLMStack (line 427) | class GLMStack(torch.nn.Module):
    method __init__ (line 460) | def __init__(self,
    method forward (line 519) | def forward(self, hidden_states, position_ids, attention_mask, memory_...
    method update_mems (line 595) | def update_mems(self, hiddens, mems):
  class GLMPreTrainedModel (line 610) | class GLMPreTrainedModel(PreTrainedModel):
    method _init_weights (line 621) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 637) | def _set_gradient_checkpointing(self, module, value=False):
  class GLMModel (line 697) | class GLMModel(GLMPreTrainedModel):
    method __init__ (line 711) | def __init__(self, config):
    method forward (line 741) | def forward(
  class GLMForMultipleChoice (line 782) | class GLMForMultipleChoice(GLMPreTrainedModel):
    method __init__ (line 783) | def __init__(self, config):
    method forward (line 788) | def forward(
  class GLMForConditionalGeneration (line 824) | class GLMForConditionalGeneration(GLMPreTrainedModel):
    method __init__ (line 825) | def __init__(self, config):
    method _reorder_cache (line 830) | def _reorder_cache(self, past, beam_idx):
    method prepare_inputs_for_generation (line 842) | def prepare_inputs_for_generation(self, input_ids, past=None, position...
    method forward (line 865) | def forward(

FILE: src/resources/models/glm-350M-chinese/tokenization_glm.py
  class GLMBatchEncoding (line 17) | class GLMBatchEncoding(BatchEncoding):
    method to (line 19) | def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
  class GLMTokenizerMixin (line 38) | class GLMTokenizerMixin:
    method sop_token (line 40) | def sop_token(self) -> Optional[str]:
    method sop_token_id (line 44) | def sop_token_id(self) -> Optional[int]:
    method eop_token (line 51) | def eop_token(self) -> Optional[str]:
    method eop_token_id (line 55) | def eop_token_id(self) -> Optional[int]:
    method gmask_token_id (line 62) | def gmask_token_id(self) -> int:
    method smask_token_id (line 66) | def smask_token_id(self) -> int:
    method mask_token_ids (line 70) | def mask_token_ids(self):
    method _build_input_for_multiple_choice (line 73) | def _build_input_for_multiple_choice(self, context, choices):
    method _pad_batch (line 110) | def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_len...
    method _collate (line 122) | def _collate(self, samples):
    method build_inputs_for_multiple_choice (line 146) | def build_inputs_for_multiple_choice(self, model_input: BatchEncoding,...
    method build_inputs_for_generation (line 153) | def build_inputs_for_generation(self, model_input: BatchEncoding, max_...
  class GLMRobertaTokenizer (line 216) | class GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):
    method gmask_token_id (line 221) | def gmask_token_id(self) -> int:
    method smask_token_id (line 225) | def smask_token_id(self) -> int:
    method mask_token_ids (line 229) | def mask_token_ids(self):
  class GLMChineseTokenizer (line 233) | class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
    method __init__ (line 237) | def __init__(self, vocab_file, **kwargs):
    method vocab_size (line 244) | def vocab_size(self):
    method get_vocab (line 247) | def get_vocab(self):
    method _tokenize (line 252) | def _tokenize(self, text, **kwargs):
    method _convert_token_to_id (line 255) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 259) | def _convert_id_to_token(self, index):
    method convert_tokens_to_string (line 263) | def convert_tokens_to_string(self, tokens):
    method save_vocabulary (line 266) | def save_vocabulary(self, save_directory: str, filename_prefix: Option...
    method build_inputs_with_special_tokens (line 283) | def build_inputs_with_special_tokens(
  class GLMGPT2Tokenizer (line 315) | class GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):
    method build_inputs_with_special_tokens (line 319) | def build_inputs_with_special_tokens(
  class GLMBertTokenizer (line 341) | class GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):
    method gmask_token_id (line 346) | def gmask_token_id(self) -> int:
    method smask_token_id (line 350) | def smask_token_id(self) -> int:
    method mask_token_ids (line 354) | def mask_token_ids(self):
  class GLMTokenizer (line 358) | class GLMTokenizer:
    method from_pretrained (line 360) | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwa...

FILE: src/resources/models/pangu-13B/configuration_gptpangu.py
  class GPTPanguConfig (line 4) | class GPTPanguConfig(PretrainedConfig):
    method __init__ (line 8) | def __init__(

FILE: src/resources/models/pangu-13B/modeling_gptpangu.py
  class GPTPanguAttention (line 19) | class GPTPanguAttention(nn.Module):
    method __init__ (line 20) | def __init__(self, config):
    method _attn (line 51) | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    method _split_heads (line 79) | def _split_heads(self, tensor, num_heads, attn_head_size):
    method _merge_heads (line 87) | def _merge_heads(self, tensor, num_heads, attn_head_size):
    method forward (line 95) | def forward(
  class GPTPanguMLP (line 136) | class GPTPanguMLP(nn.Module):
    method __init__ (line 137) | def __init__(self, intermediate_size, config):  # in MLP: intermediate...
    method forward (line 145) | def forward(self, hidden_states):
  class GPTPanguBlock (line 153) | class GPTPanguBlock(nn.Module):
    method __init__ (line 154) | def __init__(self, config):
    method forward (line 164) | def forward(
  class GPTPanguPreTrainedModel (line 204) | class GPTPanguPreTrainedModel(PreTrainedModel):
    method __init__ (line 214) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 217) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 244) | def _set_gradient_checkpointing(self, module, value=False):
  class GPTPanguModel (line 249) | class GPTPanguModel(GPTPanguPreTrainedModel):
    method __init__ (line 250) | def __init__(self, config):
    method get_input_embeddings (line 267) | def get_input_embeddings(self):
    method set_input_embeddings (line 270) | def set_input_embeddings(self, new_embeddings):
    method forward (line 273) | def forward(
  class GPTPanguForCausalLM (line 438) | class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
    method __init__ (line 439) | def __init__(self, config):
    method get_output_embeddings (line 447) | def get_output_embeddings(self):
    method set_output_embeddings (line 450) | def set_output_embeddings(self, new_embeddings):
    method prepare_inputs_for_generation (line 453) | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    method forward (line 481) | def forward(
    method _reorder_cache (line 543) | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.T...

FILE: src/resources/models/pangu-13B/tokenization_gptpangu.py
  class GPTPanguTokenizer (line 17) | class GPTPanguTokenizer(PreTrainedTokenizer):
    method __init__ (line 23) | def __init__(
    method build_inputs_with_special_tokens (line 37) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method tokenize (line 68) | def tokenize(self, text, **kwargs):
    method convert_tokens_to_ids (line 73) | def convert_tokens_to_ids(self, tokens):
    method _convert_token_to_id (line 99) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 102) | def _convert_id_to_token(self, index):
    method convert_ids_to_tokens (line 105) | def convert_ids_to_tokens(self, ids):
    method decode (line 108) | def decode(self, ids, **kwargs):
    method vocab_size (line 121) | def vocab_size(self) -> int:
    method save_vocabulary (line 127) | def save_vocabulary(self, save_directory, filename_prefix=None):

FILE: src/resources/models/pangu-2.6B/configuration_gptpangu.py
  class GPTPanguConfig (line 4) | class GPTPanguConfig(PretrainedConfig):
    method __init__ (line 8) | def __init__(

FILE: src/resources/models/pangu-2.6B/modeling_gptpangu.py
  class GPTPanguAttention (line 19) | class GPTPanguAttention(nn.Module):
    method __init__ (line 20) | def __init__(self, config):
    method _attn (line 51) | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    method _split_heads (line 79) | def _split_heads(self, tensor, num_heads, attn_head_size):
    method _merge_heads (line 87) | def _merge_heads(self, tensor, num_heads, attn_head_size):
    method forward (line 95) | def forward(
  class GPTPanguMLP (line 136) | class GPTPanguMLP(nn.Module):
    method __init__ (line 137) | def __init__(self, intermediate_size, config):  # in MLP: intermediate...
    method forward (line 145) | def forward(self, hidden_states):
  class GPTPanguBlock (line 153) | class GPTPanguBlock(nn.Module):
    method __init__ (line 154) | def __init__(self, config):
    method forward (line 164) | def forward(
  class GPTPanguPreTrainedModel (line 204) | class GPTPanguPreTrainedModel(PreTrainedModel):
    method __init__ (line 214) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 217) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 244) | def _set_gradient_checkpointing(self, module, value=False):
  class GPTPanguModel (line 249) | class GPTPanguModel(GPTPanguPreTrainedModel):
    method __init__ (line 250) | def __init__(self, config):
    method get_input_embeddings (line 267) | def get_input_embeddings(self):
    method set_input_embeddings (line 270) | def set_input_embeddings(self, new_embeddings):
    method forward (line 273) | def forward(
  class GPTPanguForCausalLM (line 438) | class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
    method __init__ (line 439) | def __init__(self, config):
    method get_output_embeddings (line 447) | def get_output_embeddings(self):
    method set_output_embeddings (line 450) | def set_output_embeddings(self, new_embeddings):
    method prepare_inputs_for_generation (line 453) | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    method forward (line 481) | def forward(
    method _reorder_cache (line 543) | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.T...

FILE: src/resources/models/pangu-2.6B/tokenization_gptpangu.py
  class GPTPanguTokenizer (line 17) | class GPTPanguTokenizer(PreTrainedTokenizer):
    method __init__ (line 23) | def __init__(
    method build_inputs_with_special_tokens (line 37) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method tokenize (line 68) | def tokenize(self, text, **kwargs):
    method convert_tokens_to_ids (line 73) | def convert_tokens_to_ids(self, tokens):
    method _convert_token_to_id (line 99) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 102) | def _convert_id_to_token(self, index):
    method convert_ids_to_tokens (line 105) | def convert_ids_to_tokens(self, ids):
    method decode (line 108) | def decode(self, ids, **kwargs):
    method vocab_size (line 121) | def vocab_size(self) -> int:
    method save_vocabulary (line 127) | def save_vocabulary(self, save_directory, filename_prefix=None):

FILE: src/resources/models/pangu-350M/configuration_gptpangu.py
  class GPTPanguConfig (line 4) | class GPTPanguConfig(PretrainedConfig):
    method __init__ (line 8) | def __init__(

FILE: src/resources/models/pangu-350M/modeling_gptpangu.py
  class GPTPanguAttention (line 19) | class GPTPanguAttention(nn.Module):
    method __init__ (line 20) | def __init__(self, config):
    method _attn (line 51) | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    method _split_heads (line 79) | def _split_heads(self, tensor, num_heads, attn_head_size):
    method _merge_heads (line 87) | def _merge_heads(self, tensor, num_heads, attn_head_size):
    method forward (line 95) | def forward(
  class GPTPanguMLP (line 136) | class GPTPanguMLP(nn.Module):
    method __init__ (line 137) | def __init__(self, intermediate_size, config):  # in MLP: intermediate...
    method forward (line 145) | def forward(self, hidden_states):
  class GPTPanguBlock (line 153) | class GPTPanguBlock(nn.Module):
    method __init__ (line 154) | def __init__(self, config):
    method forward (line 164) | def forward(
  class GPTPanguPreTrainedModel (line 204) | class GPTPanguPreTrainedModel(PreTrainedModel):
    method __init__ (line 214) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 217) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 244) | def _set_gradient_checkpointing(self, module, value=False):
  class GPTPanguModel (line 249) | class GPTPanguModel(GPTPanguPreTrainedModel):
    method __init__ (line 250) | def __init__(self, config):
    method get_input_embeddings (line 267) | def get_input_embeddings(self):
    method set_input_embeddings (line 270) | def set_input_embeddings(self, new_embeddings):
    method forward (line 273) | def forward(
  class GPTPanguForCausalLM (line 438) | class GPTPanguForCausalLM(GPTPanguPreTrainedModel):
    method __init__ (line 439) | def __init__(self, config):
    method get_output_embeddings (line 447) | def get_output_embeddings(self):
    method set_output_embeddings (line 450) | def set_output_embeddings(self, new_embeddings):
    method prepare_inputs_for_generation (line 453) | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    method forward (line 481) | def forward(
    method _reorder_cache (line 543) | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.T...

FILE: src/resources/models/pangu-350M/tokenization_gptpangu.py
  class GPTPanguTokenizer (line 17) | class GPTPanguTokenizer(PreTrainedTokenizer):
    method __init__ (line 23) | def __init__(
    method build_inputs_with_special_tokens (line 37) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...
    method tokenize (line 68) | def tokenize(self, text, **kwargs):
    method convert_tokens_to_ids (line 73) | def convert_tokens_to_ids(self, tokens):
    method _convert_token_to_id (line 99) | def _convert_token_to_id(self, token):
    method _convert_id_to_token (line 102) | def _convert_id_to_token(self, index):
    method convert_ids_to_tokens (line 105) | def convert_ids_to_tokens(self, ids):
    method decode (line 108) | def decode(self, ids, **kwargs):
    method vocab_size (line 121) | def vocab_size(self) -> int:
    method save_vocabulary (line 127) | def save_vocabulary(self, save_directory, filename_prefix=None):

FILE: src/resources/models/qwen-7B-chat/configuration_qwen.py
  class QWenConfig (line 9) | class QWenConfig(PretrainedConfig):
    method __init__ (line 19) | def __init__(

FILE: src/resources/models/qwen-7B-chat/modeling_qwen.py
  class FlashSelfAttention (line 76) | class FlashSelfAttention(torch.nn.Module):
    method __init__ (line 77) | def __init__(
    method forward (line 94) | def forward(self, q, k, v):
  class QWenAttention (line 140) | class QWenAttention(nn.Module):
    method __init__ (line 141) | def __init__(self, config, layer_number=None):
    method _attn (line 213) | def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    method _upcast_and_reordered_attn (line 249) | def _upcast_and_reordered_attn(
    method _split_heads (line 305) | def _split_heads(self, tensor, num_heads, attn_head_size):
    method _merge_heads (line 310) | def _merge_heads(self, tensor, num_heads, attn_head_size):
    method forward (line 315) | def forward(
  class QWenMLP (line 409) | class QWenMLP(nn.Module):
    method __init__ (line 410) | def __init__(self, config):
    method forward (line 421) | def forward(self, hidden_states):
  class QWenBlock (line 429) | class QWenBlock(nn.Module):
    method __init__ (line 430) | def __init__(self, config, layer_idx=None, num_expert=1):
    method forward (line 455) | def forward(
  class QWenPreTrainedModel (line 504) | class QWenPreTrainedModel(PreTrainedModel):
    method __init__ (line 511) | def __init__(self, *inputs, **kwargs):
    method _init_weights (line 514) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 537) | def _set_gradient_checkpointing(self, module, value=False):
  class QWenModel (line 542) | class QWenModel(QWenPreTrainedModel):
    method __init__ (line 545) | def __init__(self, config):
    method get_input_embeddings (line 583) | def get_input_embeddings(self):
    method set_input_embeddings (line 586) | def set_input_embeddings(self, new_embeddings):
    method forward (line 589) | def forward(
  class QWenLMHeadModel (line 745) | class QWenLMHeadModel(QWenPreTrainedModel):
    method __init__ (line 749) | def __init__(self, config):
    method get_output_embeddings (line 762) | def get_output_embeddings(self):
    method set_output_embeddings (line 765) | def set_output_embeddings(self, new_embeddings):
    method prepare_inputs_for_generation (line 768) | def prepare_inputs_for_generation(
    method forward (line 804) | def forward(
    method _reorder_cache (line 868) | def _reorder_cache(
    method chat (line 880) | def chat(
    method generate (line 926) | def generate(
  class RotaryEmbedding (line 964) | class RotaryEmbedding(torch.nn.Module):
    method __init__ (line 965) | def __init__(self, dim, base=10000):
    method update_rotary_pos_emb_cache (line 977) | def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha...
    method forward (line 991) | def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
  function _rotate_half (line 996) | def _rotate_half(x):
  function apply_rotary_pos_emb (line 1004) | def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
  class RMSNorm (line 1021) | class RMSNorm(torch.nn.Module):
    method __init__ (line 1022) | def __init__(self, dim: int, eps: float = 1e-6):
    method _norm (line 1027) | def _norm(self, x):
    method forward (line 1030) | def forward(self, x):

FILE: src/resources/models/qwen-7B-chat/qwen_generation_utils.py
  function pad_batch (line 25) | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> B...
  function get_ltor_masks_and_position_ids (line 33) | def get_ltor_masks_and_position_ids(
  function get_batch (line 94) | def get_batch(context_tokens: torch.LongTensor, eod_id: int):
  function get_stop_words_ids (line 109) | def get_stop_words_ids(chat_format, tokenizer):
  function make_context (line 119) | def make_context(
  function _decode_default (line 192) | def _decode_default(
  function _decode_chatml (line 224) | def _decode_chatml(
  function decode_tokens (line 259) | def decode_tokens(
  class StopWordsLogitsProcessor (line 296) | class StopWordsLogitsProcessor(LogitsProcessor):
    method __init__ (line 309) | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_...
    method __call__ (line 343) | def __call__(
    method _tokens_match (line 352) | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[in...
    method _calc_stopped_samples (line 365) | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iter...
  function top_k_logits (line 379) | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
  function switch (line 409) | def switch(val1, val2, boolean):

FILE: src/resources/models/qwen-7B-chat/tokenization_qwen.py
  class QWenTokenizer (line 26) | class QWenTokenizer(PreTrainedTokenizer):
    method __init__ (line 33) | def __init__(
    method __len__ (line 134) | def __len__(self):
    method get_vocab (line 137) | def get_vocab(self):
    method convert_tokens_to_ids (line 140) | def convert_tokens_to_ids(self, tokens):
    method save_vocabulary (line 163) | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
    method tokenize (line 177) | def tokenize(self, text: str, **kwargs) -> List[str]:
    method convert_tokens_to_string (line 197) | def convert_tokens_to_string(self, tokens: List[str]) -> str:
    method vocab_size (line 209) | def vocab_size(self):
    method _convert_id_to_token (line 212) | def _convert_id_to_token(self, index: int) -> str:
    method _convert_token_to_id (line 217) | def _convert_token_to_id(self, token: str) -> int:
    method all_special_tokens (line 222) | def all_special_tokens(self) -> List[str]:
    method all_special_ids (line 232) | def all_special_ids(self) -> List[int]:
    method _tokenize (line 239) | def _tokenize(self, text, **kwargs):
    method _decode (line 248) | def _decode(
    method build_inputs_with_special_tokens (line 260) | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=No...

FILE: src/train_dpo.py
  function preprocess_logits_for_metrics (line 27) | def preprocess_logits_for_metrics(logits, labels):
  function get_parser (line 34) | def get_parser():
  function main (line 110) | def main():

FILE: src/train_reward.py
  function get_parser (line 24) | def get_parser():
  function main (line 92) | def main():

FILE: src/train_rlhf.py
  function get_parser (line 26) | def get_parser():
  function create_datasets (line 144) | def create_datasets(args, tokenizer_padding_from_left, ppo_ptx_enabled, ...
  function create_dataloader (line 165) | def create_dataloader(args, train_dataset, pretrain_dataset=None):
  function main (line 203) | def main():

FILE: src/train_rlhf_trlx.py
  function get_parser (line 24) | def get_parser():
  function train (line 86) | def train(model_path: Optional[str] = None,
  function main (line 166) | def main():

FILE: src/train_sft.py
  function preprocess_logits_for_metrics (line 26) | def preprocess_logits_for_metrics(logits, labels):
  function get_parser (line 33) | def get_parser():
  function main (line 107) | def main():

FILE: src/train_sft_wo_trainer.py
  function preprocess_logits_for_metrics (line 27) | def preprocess_logits_for_metrics(logits, labels):
  function get_parser (line 34) | def get_parser():
  function main (line 106) | def main():

FILE: src/utils/config.py
  function merge (line 18) | def merge(base: Dict, update: Dict, updated: Set) -> Dict:
  function _merge_dicts (line 31) | def _merge_dicts(base: Dict, update: Dict) -> Dict:
  class ModelConfig (line 46) | class ModelConfig:
    method from_dict (line 81) | def from_dict(cls, config: Dict[str, Any]):
  class TokenizerConfig (line 86) | class TokenizerConfig:
    method from_dict (line 105) | def from_dict(cls, config: Dict[str, Any]):
  class OptimizerConfig (line 110) | class OptimizerConfig:
    method from_dict (line 125) | def from_dict(cls, config: Dict[str, Any]):
  class SchedulerConfig (line 130) | class SchedulerConfig:
    method from_dict (line 145) | def from_dict(cls, config: Dict[str, Any]):
  class TrainConfig (line 150) | class TrainConfig:
    method from_dict (line 239) | def from_dict(cls, config: Dict[str, Any]):
  class TRLConfig (line 244) | class TRLConfig:
    method load_yaml (line 257) | def load_yaml(cls, yml_fp: str):
    method to_dict (line 268) | def to_dict(self):
    method evolve (line 283) | def evolve(self, **kwargs) -> "TRLConfig":
    method from_dict (line 294) | def from_dict(cls, config: Dict):
    method update (line 308) | def update(cls, baseconfig: Dict, config: Dict):
    method __str__ (line 321) | def __str__(self):
  function default_ppo_config (line 328) | def default_ppo_config():
  function default_ilql_config (line 373) | def default_ilql_config():
  function default_sft_config (line 408) | def default_sft_config():
  function get_train_ds_config (line 435) | def get_train_ds_config(global_batch_size=32,
  function get_eval_ds_config (line 485) | def get_eval_ds_config(global_batch_size=32, micro_batch_size=4, gradien...

FILE: src/utils/file_utils.py
  function set_seed (line 14) | def set_seed(seed_val=42):
  function print_rank_0 (line 21) | def print_rank_0(*message):
  function significant (line 29) | def significant(x: Number, ndigits=2) -> Number:
  function print_gpu_utilization (line 53) | def print_gpu_utilization(prefix: str = "", index: int = 0, only_rank_0:...
  function print_gpu_utilization_torch (line 65) | def print_gpu_utilization_torch(prefix: str = "", index: int = 0, only_r...
  function print_trainable_parameters (line 85) | def print_trainable_parameters(model):

FILE: src/utils/loading.py
  function _trainer_unavailble (line 38) | def _trainer_unavailble(name):
  function prepare_decoder_attention_mask (line 47) | def prepare_decoder_attention_mask(self, *args, **kwargs):
  function chatglm_auto_configure_device_map (line 51) | def chatglm_auto_configure_device_map(num_gpus: int, model_name: str, lo...
  function llama_and_baichuan_auto_configure_device_map (line 98) | def llama_and_baichuan_auto_configure_device_map(num_gpus: int, model_na...
  function load_params_8bit_or_4bit (line 132) | def load_params_8bit_or_4bit(args, model: PreTrainedModel) -> Dict:
  function load_tokenizer_and_model (line 166) | def load_tokenizer_and_model(args, with_trainer: bool = True) -> Tuple[P...
  function to_peft_model (line 285) | def to_peft_model(args, model: PreTrainedModel) -> PreTrainedModel:
  function load_checkpoint (line 320) | def load_checkpoint(args, model: PreTrainedModel, strict: bool = True) -...
  function get_trainer (line 329) | def get_trainer(name: str) -> Callable:
  function get_pipeline (line 340) | def get_pipeline(name: str) -> Callable:

FILE: src/utils/method_configs.py
  function register_method (line 20) | def register_method(name):
  class MethodConfig (line 44) | class MethodConfig:
    method from_dict (line 55) | def from_dict(cls, config: Dict[str, Any]):
  function get_method (line 59) | def get_method(name: str) -> MethodConfig:
  class PPOConfig (line 73) | class PPOConfig(MethodConfig):
    method get_advantages_and_returns (line 133) | def get_advantages_and_returns(
    method loss (line 171) | def loss(
  class SFTConfig (line 238) | class SFTConfig(MethodConfig):
  class ILQLConfig (line 251) | class ILQLConfig(MethodConfig):
    method loss (line 262) | def loss(self, outputs, labels):

FILE: src/utils/modeling_utils.py
  function _make_causal_mask (line 41) | def _make_causal_mask(
  function _expand_mask (line 59) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  function _prepare_decoder_attention_mask (line 73) | def _prepare_decoder_attention_mask(attention_mask, input_shape, input_e...
  function get_distributed_config (line 102) | def get_distributed_config(accelerator: Accelerator):
  class OptimizerName (line 127) | class OptimizerName(str, Enum):
  function get_optimizer_class (line 137) | def get_optimizer_class(name: OptimizerName):
  class SchedulerName (line 174) | class SchedulerName(str, Enum):
  function get_scheduler_class (line 181) | def get_scheduler_class(name: SchedulerName):
  class Clock (line 193) | class Clock:
    method __init__ (line 198) | def __init__(self):
    method tick (line 203) | def tick(self, samples: int = 0) -> float:
    method get_stat (line 219) | def get_stat(self, n_samp: int = 1000, reset: bool = False):
  function tree_map (line 234) | def tree_map(f, tree: Any) -> Any:
  function to_device (line 248) | def to_device(tree, device, non_blocking=False):
  function filter_non_scalars (line 255) | def filter_non_scalars(xs: Dict) -> Dict:
  function get_git_tag (line 269) | def get_git_tag() -> Tuple[str, str]:
  function make_head (line 281) | def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn....
  function freeze_bottom_causal_layers (line 290) | def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: i...
  function freeze_bottom_seq2seq_layers (line 303) | def freeze_bottom_seq2seq_layers(model: nn.Module, num_layers_unfrozen: ...
  function rhasattr (line 325) | def rhasattr(obj, attr):
  function rgetattr (line 341) | def rgetattr(obj, attr: str, *args) -> object:
  function findattr (line 354) | def findattr(obj, attrs: Tuple[str]) -> Union[object, None]:
  function hf_get_decoder (line 361) | def hf_get_decoder(model: nn.Module) -> nn.Module:
  function hf_get_decoder_final_norm (line 374) | def hf_get_decoder_final_norm(model: nn.Module) -> float:
  function hf_get_decoder_blocks (line 390) | def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]:
  function hf_get_lm_head (line 411) | def hf_get_lm_head(model: nn.Module) -> nn.Module:
  function hf_get_hidden_size (line 421) | def hf_get_hidden_size(config: transformers.PretrainedConfig) -> int:
  function hf_get_num_hidden_layers (line 433) | def hf_get_num_hidden_layers(config: transformers.PretrainedConfig) -> int:
  function get_global_statistics (line 445) | def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
  function whiten (line 460) | def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch...
  function logprobs_of_labels (line 473) | def logprobs_of_labels(logits, labels):
  function flatten_dict (line 482) | def flatten_dict(
  function get_tensor_stats (line 498) | def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int):
  class RunningMoments (line 508) | class RunningMoments:
    method __init__ (line 509) | def __init__(self):
    method update (line 519) | def update(self, xs: torch.Tensor) -> Tuple[float, float]:
  function generate_layer_regex (line 629) | def generate_layer_regex(config: transformers.PretrainedConfig, num_laye...
  function get_delta_modified_modules (line 641) | def get_delta_modified_modules(
  function get_delta_model_class (line 658) | def get_delta_model_class(model_type: str):
  function parse_delta_kwargs (line 671) | def parse_delta_kwargs(
  function regex_for_range (line 705) | def regex_for_range(min_: int, max_: int) -> str:  # noqa
  function get_optimizer_grouped_parameters (line 782) | def get_optimizer_grouped_parameters(model,
  function get_all_reduce_mean (line 808) | def get_all_reduce_mean(tensor):
  function moving_average (line 814) | def moving_average(model, model_ema, beta=0.992, device=None, zero_stage...
  function save_hf_format (line 831) | def save_hf_format(model, tokenizer, args, sub_folder=""):
  function _z3_params_to_fetch (line 852) | def _z3_params_to_fetch(param_list):
  function save_zero_three_model (line 859) | def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
  function sorted_checkpoints (line 886) | def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint", ...
  function rotate_checkpoints (line 905) | def rotate_checkpoints(save_total_limit, use_mtime=False, output_dir=Non...
  function qwen_make_context (line 931) | def qwen_make_context(

FILE: src/utils/nlp_utils.py
  function clean_text (line 7) | def clean_text(text):
Condensed preview — 135 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,499K chars).
[
  {
    "path": "README.md",
    "chars": 33700,
    "preview": "## Features\n\n提供3大功能:\n- LLM模型预训练:支持常见模型的预训练,包括:decoder结构(LLaMA、GPT)、encoder结构(GLM)\n- LLM模型评测:参考GPT类模型,基于ZeroShot和FewShot实"
  },
  {
    "path": "docs/INSTALL_ON_WINDOWS.md",
    "chars": 7201,
    "preview": "大模型相关的常见package在***Windows***上的安装方案。在以下配置上已测试成功: ```windows11 + python3.11 + cuda12.3 + cudnn8.9.5.30 + torch2.1.0+cu121"
  },
  {
    "path": "examples/data_prepare.sh",
    "chars": 638,
    "preview": "#!/bin/bash\n\nMODEL=\"pangu-2.6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#"
  },
  {
    "path": "examples/eval_pretrain.sh",
    "chars": 2941,
    "preview": "#!/bin/bash\n\n# C-Eval and MMLU benchamarks\nTASK=\"ceval\"\nMODEL=\"llama-7B\"\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt"
  },
  {
    "path": "examples/eval_reward.sh",
    "chars": 1146,
    "preview": "#!/bin/bash\n\nMODEL=\"chatglm-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#"
  },
  {
    "path": "examples/pred_pretrain.sh",
    "chars": 1123,
    "preview": "#!/bin/bash\n\nMODEL=\"llama-7B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#RO"
  },
  {
    "path": "examples/pred_reward.sh",
    "chars": 1217,
    "preview": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n"
  },
  {
    "path": "examples/pretrain.sh",
    "chars": 1399,
    "preview": "#!/bin/bash\n\nMODEL=\"llama-7B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#RO"
  },
  {
    "path": "examples/train_dpo.sh",
    "chars": 2090,
    "preview": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\n#ROOT=\"/mnt/sfevol775196/sunzeye273\"\nROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\""
  },
  {
    "path": "examples/train_reward.sh",
    "chars": 1439,
    "preview": "#!/bin/bash\n\nMODEL=\"chatglm-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#"
  },
  {
    "path": "examples/train_rlhf.sh",
    "chars": 2010,
    "preview": "#!/bin/bash\n\nREWARD_MODEL=\"pangu-350M\"\nSFT_MODEL=\"pangu-2.6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa0"
  },
  {
    "path": "examples/train_sft.sh",
    "chars": 1441,
    "preview": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\n#ROOT=\"/mnt/sfevol775196/sunzeye273\"\nROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\""
  },
  {
    "path": "notebook/data_prepare.ipynb",
    "chars": 26369,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "notebook/reward.ipynb",
    "chars": 13526,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a7d1682d\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "notebook/rlhf.ipynb",
    "chars": 13955,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e0fa8199\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "notebook/sft.ipynb",
    "chars": 19920,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6f526208\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "notebook/test.ipynb",
    "chars": 19303,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"93c83ce9\",\n   \"metadata\": {},\n   \"output"
  },
  {
    "path": "requirements.txt",
    "chars": 351,
    "preview": "jieba\nsentencepiece\nnltk>=3.8.1\nrouge-score>=0.1.2\nevaluate>=0.4.0\nnumpy==1.23.2\ntensorflow>=2.5.0\ntransformers>=4.31.0\n"
  },
  {
    "path": "src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/data/data.py",
    "chars": 96659,
    "preview": "import os\nimport json\nimport re\nimport random\nfrom typing import Tuple, List\n\nimport torch\nimport pandas as pd\n\nfrom tqd"
  },
  {
    "path": "src/data/data_types.py",
    "chars": 3945,
    "preview": "\nfrom dataclasses import dataclass\nfrom typing import Iterable\n\nfrom torchtyping import TensorType\n\n\n@dataclass\nclass Pr"
  },
  {
    "path": "src/data/pipeline.py",
    "chars": 11653,
    "preview": "\nimport os\nimport sys\nimport time\nimport json\nimport torch\n\nfrom abc import abstractmethod\nfrom typing import Optional, "
  },
  {
    "path": "src/data_prepare.py",
    "chars": 18384,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/eval_pretrain.py",
    "chars": 25734,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/models/__init__.py",
    "chars": 164,
    "preview": "from src.models.reward import RewardModel\nfrom src.models.loss import PairWiseLoss\n# from src.models.llama import LlamaF"
  },
  {
    "path": "src/models/loss.py",
    "chars": 368,
    "preview": "\nimport torch\nimport torch.nn as nn\n\n\nclass PairWiseLoss(nn.Module):\n    \"\"\"\n    Pairwise Loss for Reward Model\n    \"\"\"\n"
  },
  {
    "path": "src/models/ppo.py",
    "chars": 62415,
    "preview": "# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
  },
  {
    "path": "src/models/reward.py",
    "chars": 5853,
    "preview": "\nimport torch\nfrom torch import nn\nfrom transformers.modeling_utils import PreTrainedModel\n\n\nclass RewardModel(PreTraine"
  },
  {
    "path": "src/models/rlhf_engine.py",
    "chars": 18201,
    "preview": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport time\nimport torch\nimport "
  },
  {
    "path": "src/models/trainer.py",
    "chars": 103832,
    "preview": "from collections import defaultdict\n\nimport deepspeed\nimport sys\nimport json\nimport os\nimport ray\nimport torch\nimport lo"
  },
  {
    "path": "src/pretrain.py",
    "chars": 19587,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/pretrain_wo_trainer.py",
    "chars": 21319,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/resources/config/deepspeed/Experiment.md",
    "chars": 9560,
    "preview": "\n为验证不同预训练模型使用deepspeed的训练效率是否能达到官方宣称的效果(加速、节省GPU等),进行了benchmarking\n- 实验场景:SFT阶段训练\n- 实验数据:SFT & Reward Data的验证集,共1万条样本\n- "
  },
  {
    "path": "src/resources/config/deepspeed/README.md",
    "chars": 12643,
    "preview": "# Deepspeed Configuration\n- deepspeed==0.8.2\n- transformers==4.26.1\n\n- using ```transformers.Trainer``` and ```transform"
  },
  {
    "path": "src/resources/config/deepspeed/stage-1.json",
    "chars": 643,
    "preview": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \""
  },
  {
    "path": "src/resources/config/deepspeed/stage-2.json",
    "chars": 949,
    "preview": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \""
  },
  {
    "path": "src/resources/config/deepspeed/stage-3-no_trainer.json",
    "chars": 1492,
    "preview": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \""
  },
  {
    "path": "src/resources/config/deepspeed/stage-3.json",
    "chars": 1158,
    "preview": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \""
  },
  {
    "path": "src/resources/config/log.conf",
    "chars": 551,
    "preview": "[loggers]\nkeys=root,main\n\n[handlers]\nkeys=consoleHanlder,fileHandler\n\n[formatters]\nkeys=fmt\n\n[logger_root]\nlevel=DEBUG\nh"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_glm-350M-chinese.yml",
    "chars": 1061,
    "preview": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 50000\n  batch_size: 1\n  checkpoint_interval: 1000\n  eval_interval: 1"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_glm-small.yml",
    "chars": 1054,
    "preview": "train:\n  seq_length: 256\n  epochs: 1\n  total_steps: 10000\n  batch_size: 2\n  checkpoint_interval: 100\n  eval_interval: 10"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_pangu-2.6B.yml",
    "chars": 1021,
    "preview": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 50000\n  batch_size: 4\n  checkpoint_interval: 1000\n  eval_interval: 1"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_pangu-350M.yml",
    "chars": 1024,
    "preview": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 100000\n  batch_size: 16\n  checkpoint_interval: 1000\n  eval_interval:"
  },
  {
    "path": "src/resources/dict/jieba_custom_dict.txt",
    "chars": 32,
    "preview": "<s>\n</s>\n<eot>\n<unk>\n<sep>\n<pad>"
  },
  {
    "path": "src/resources/eval/ceval/subject_mapping.json",
    "chars": 4460,
    "preview": "{\n  \"computer_network\": [\n    \"Computer Network\",\n    \"计算机网络\",\n    \"STEM\"\n  ],\n  \"operating_system\": [\n    \"Operating Sy"
  },
  {
    "path": "src/resources/eval/ceval/submission_example.json",
    "chars": 4617,
    "preview": "{\n    \"chinese_language_and_literature\": {\n        \"0\": \"A\",\n        \"1\": \"B\",\n        \"2\": \"B\"\n    },\n    \"clinical_med"
  },
  {
    "path": "src/resources/eval/mmlu/subject_mapping.json",
    "chars": 4013,
    "preview": "{\n  \"moral_scenarios\": [\n    \"Moral Scenarios\",\n    \"Humanities\"\n  ],\n  \"us_foreign_policy\": [\n    \"US Foreign Policy\",\n"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/config.json",
    "chars": 715,
    "preview": "{\n  \"_from_model_config\": true,\n  \"architectures\": [\n    \"BaichuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"c"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/configuration_baichuan.py",
    "chars": 1490,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nfrom transformers.configuration_utils impor"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/modeling_baichuan.py",
    "chars": 21761,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport math\nfrom typing import List, Option"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/tokenization_baichuan.py",
    "chars": 8720,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport os\nfrom shutil import copyfile\nfrom "
  },
  {
    "path": "src/resources/models/baichuan-13B-base/tokenizer_config.json",
    "chars": 953,
    "preview": "{\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_baichuan"
  },
  {
    "path": "src/resources/models/baichuan-13B-chat/config.json",
    "chars": 680,
    "preview": "{\n  \"_from_model_config\": true,\n  \"architectures\": [\n    \"BaichuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"c"
  },
  {
    "path": "src/resources/models/baichuan-13B-chat/configuration_baichuan.py",
    "chars": 1490,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nfrom transformers.configuration_utils impor"
  },
  {
    "path": "src/resources/models/baichuan-13B-chat/modeling_baichuan.py",
    "chars": 21763,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport math\nfrom typing import List, Option"
  },
  {
    "path": "src/resources/models/baichuan-13B-chat/tokenization_baichuan.py",
    "chars": 8720,
    "preview": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport os\nfrom shutil import copyfile\nfrom "
  },
  {
    "path": "src/resources/models/baichuan-13B-chat/tokenizer_config.json",
    "chars": 953,
    "preview": "{\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_baichuan"
  },
  {
    "path": "src/resources/models/baichuan-7B/config.json",
    "chars": 656,
    "preview": "{\n  \"architectures\": [\n    \"BaiChuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_baichuan.BaiChuan"
  },
  {
    "path": "src/resources/models/baichuan-7B/configuration_baichuan.py",
    "chars": 2345,
    "preview": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on"
  },
  {
    "path": "src/resources/models/baichuan-7B/modeling_baichuan.py",
    "chars": 29233,
    "preview": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on"
  },
  {
    "path": "src/resources/models/baichuan-7B/tokenization_baichuan.py",
    "chars": 9574,
    "preview": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on"
  },
  {
    "path": "src/resources/models/baichuan-7B/tokenizer_config.json",
    "chars": 801,
    "preview": "{\n  \"auto_map\": {\n    \"AutoTokenizer\": [\"tokenization_baichuan.BaiChuanTokenizer\", null]\n  },\n  \"add_bos_token\": true,\n "
  },
  {
    "path": "src/resources/models/chatglm-6B/README.md",
    "chars": 2852,
    "preview": "---\nlanguage:\n- zh\n- en\ntags:\n- glm\n- chatglm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n"
  },
  {
    "path": "src/resources/models/chatglm-6B/config.json",
    "chars": 725,
    "preview": "{\n  \"_name_or_path\": \"sunzeyeah/chatglm-6B\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfi"
  },
  {
    "path": "src/resources/models/chatglm-6B/configuration_chatglm.py",
    "chars": 4118,
    "preview": "\"\"\" ChatGLM model configuration \"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.uti"
  },
  {
    "path": "src/resources/models/chatglm-6B/modeling_chatglm.py",
    "chars": 56105,
    "preview": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport os\nimport warnings\nimport re\nimport sys\n\nimport torch\nimp"
  },
  {
    "path": "src/resources/models/chatglm-6B/quantization.py",
    "chars": 15053,
    "preview": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctype"
  },
  {
    "path": "src/resources/models/chatglm-6B/tokenization_chatglm.py",
    "chars": 14195,
    "preview": "\"\"\"Tokenization classes for ChatGLM.\"\"\"\nimport sys\nimport unicodedata\nfrom typing import List, Optional, Union\nfrom func"
  },
  {
    "path": "src/resources/models/chatglm-6B/tokenizer_config.json",
    "chars": 416,
    "preview": "{\n  \"name_or_path\": \"THUDM/chatglm-6b\",\n  \"bos_token\": \"<sop>\",\n  \"eop_token\": \"<eop>\",\n  \"eos_token\": \"</s>\",\n  \"gmask_"
  },
  {
    "path": "src/resources/models/chatglm2-6B/config.json",
    "chars": 1103,
    "preview": "{\n  \"model_type\": \"chatglm\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configurati"
  },
  {
    "path": "src/resources/models/chatglm2-6B/configuration_chatglm.py",
    "chars": 2057,
    "preview": "from transformers import PretrainedConfig\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    def __init__(\n        self,\n     "
  },
  {
    "path": "src/resources/models/chatglm2-6B/modeling_chatglm.py",
    "chars": 48154,
    "preview": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport warnings\nimport re\nimport sys\n\nimport torch\nimport torch."
  },
  {
    "path": "src/resources/models/chatglm2-6B/quantization.py",
    "chars": 14692,
    "preview": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctype"
  },
  {
    "path": "src/resources/models/chatglm2-6B/tokenization_chatglm.py",
    "chars": 10557,
    "preview": "import os\nimport torch\nfrom typing import List, Optional, Union, Dict\nfrom sentencepiece import SentencePieceProcessor\nf"
  },
  {
    "path": "src/resources/models/chatglm2-6B/tokenizer_config.json",
    "chars": 243,
    "preview": "{\n  \"name_or_path\": \"THUDM/chatglm-6b\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"ChatGLM"
  },
  {
    "path": "src/resources/models/chatglm3-6B/config.json",
    "chars": 1320,
    "preview": "{\n  \"_name_or_path\": \"THUDM/chatglm3-6b\",\n  \"model_type\": \"chatglm\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"aut"
  },
  {
    "path": "src/resources/models/chatglm3-6B/configuration_chatglm.py",
    "chars": 2332,
    "preview": "from transformers import PretrainedConfig\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    model_type = \"chatglm\"\n    def __"
  },
  {
    "path": "src/resources/models/chatglm3-6B/modeling_chatglm.py",
    "chars": 55604,
    "preview": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport warnings\nimport re\nimport sys\n\nimport torch\nimport torch."
  },
  {
    "path": "src/resources/models/chatglm3-6B/quantization.py",
    "chars": 14692,
    "preview": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctype"
  },
  {
    "path": "src/resources/models/chatglm3-6B/tokenization_chatglm.py",
    "chars": 11313,
    "preview": "import json\nimport os\nimport torch\nfrom typing import List, Optional, Union, Dict\nfrom sentencepiece import SentencePiec"
  },
  {
    "path": "src/resources/models/chatglm3-6B/tokenizer_config.json",
    "chars": 244,
    "preview": "{\n  \"name_or_path\": \"THUDM/chatglm3-6b\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"ChatGL"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/README.md",
    "chars": 3152,
    "preview": "---\nlanguage:\n- zh\ntags:\n- glm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n本仓库由[THUDM"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/added_tokens.json",
    "chars": 223,
    "preview": "{\n  \"<|endoftext|>\": 50000,\n  \"[SEP]\": 50001,\n  \"[CLS]\": 50002,\n  \"[MASK]\": 50003,\n  \"[UNUSED1]\": 50004,\n  \"[UNUSED2]\": "
  },
  {
    "path": "src/resources/models/glm-10B-chinese/config.json",
    "chars": 1003,
    "preview": "{\n  \"model_name_or_path\": \"sunzeyeah/glm-10B-chinese\",\n  \"architectures\": [\n    \"GLMModel\"\n  ],\n  \"attention_dropout_pro"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/configuration_glm.py",
    "chars": 6246,
    "preview": "# coding=utf-8\n# Copyright 2022 shunxing1234 and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the "
  },
  {
    "path": "src/resources/models/glm-10B-chinese/modeling_glm.py",
    "chars": 39806,
    "preview": "# coding=utf-8\n# Copyright 2022 shunxing1234 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apac"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/tokenization_glm.py",
    "chars": 16387,
    "preview": "import os\nfrom typing import Optional, Tuple, List, Union\nfrom shutil import copyfile\nimport torch\n\nfrom transformers im"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/tokenizer_config.json",
    "chars": 504,
    "preview": "{\n  \"name_or_path\": \"BAAI/glm-large-chinese\",\n  \"eos_token\": \"<|endoftext|>\",\n  \"pad_token\": \"<|endoftext|>\",\n  \"cls_tok"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/README.md",
    "chars": 3163,
    "preview": "---\nlanguage:\n- zh\ntags:\n- glm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n本仓库由[THUDM"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/added_tokens.json",
    "chars": 223,
    "preview": "{\n  \"<|endoftext|>\": 50000,\n  \"[SEP]\": 50001,\n  \"[CLS]\": 50002,\n  \"[MASK]\": 50003,\n  \"[UNUSED1]\": 50004,\n  \"[UNUSED2]\": "
  },
  {
    "path": "src/resources/models/glm-350M-chinese/config.json",
    "chars": 961,
    "preview": "{\n  \"model_name_or_path\": \"sunzeyeah/glm-350M-chinese\",\n  \"architectures\": [\n    \"GLMModel\"\n  ],\n  \"attention_dropout_pr"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/configuration_glm.py",
    "chars": 6239,
    "preview": "# coding=utf-8\n# Copyright 2022 shunxing1234 and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the "
  },
  {
    "path": "src/resources/models/glm-350M-chinese/modeling_glm.py",
    "chars": 36465,
    "preview": "# coding=utf-8\n# Copyright 2022 shunxing1234 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apac"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/tokenization_glm.py",
    "chars": 16537,
    "preview": "import os\nfrom typing import Optional, Tuple, List, Union\nfrom shutil import copyfile\nimport torch\n\nfrom transformers im"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/tokenizer_config.json",
    "chars": 504,
    "preview": "{\n  \"name_or_path\": \"BAAI/glm-large-chinese\",\n  \"eos_token\": \"<|endoftext|>\",\n  \"pad_token\": \"<|endoftext|>\",\n  \"cls_tok"
  },
  {
    "path": "src/resources/models/pangu-13B/config.json",
    "chars": 950,
    "preview": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-13B\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguForC"
  },
  {
    "path": "src/resources/models/pangu-13B/configuration_gptpangu.py",
    "chars": 1734,
    "preview": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type ="
  },
  {
    "path": "src/resources/models/pangu-13B/modeling_gptpangu.py",
    "chars": 22409,
    "preview": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimpor"
  },
  {
    "path": "src/resources/models/pangu-13B/tokenization_gptpangu.py",
    "chars": 5037,
    "preview": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import"
  },
  {
    "path": "src/resources/models/pangu-13B/tokenizer_config.json",
    "chars": 324,
    "preview": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n"
  },
  {
    "path": "src/resources/models/pangu-2.6B/config.json",
    "chars": 987,
    "preview": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-2.6B\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguFor"
  },
  {
    "path": "src/resources/models/pangu-2.6B/configuration_gptpangu.py",
    "chars": 1734,
    "preview": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type ="
  },
  {
    "path": "src/resources/models/pangu-2.6B/modeling_gptpangu.py",
    "chars": 22409,
    "preview": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimpor"
  },
  {
    "path": "src/resources/models/pangu-2.6B/tokenization_gptpangu.py",
    "chars": 5037,
    "preview": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import"
  },
  {
    "path": "src/resources/models/pangu-2.6B/tokenizer_config.json",
    "chars": 324,
    "preview": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n"
  },
  {
    "path": "src/resources/models/pangu-350M/README-Reward.md",
    "chars": 1215,
    "preview": "Link to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n# Model Description\n\nPangu-α is proposed by a joint tech"
  },
  {
    "path": "src/resources/models/pangu-350M/README-SFT.md",
    "chars": 2326,
    "preview": "Link to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n# Model Description\n\nPangu-α is proposed by a joint tech"
  },
  {
    "path": "src/resources/models/pangu-350M/README.md",
    "chars": 1270,
    "preview": "---\nlanguage:\n- zh\ntags:\n- pangu\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n\n# Model"
  },
  {
    "path": "src/resources/models/pangu-350M/config.json",
    "chars": 930,
    "preview": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-350M\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguFor"
  },
  {
    "path": "src/resources/models/pangu-350M/configuration_gptpangu.py",
    "chars": 1873,
    "preview": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type ="
  },
  {
    "path": "src/resources/models/pangu-350M/modeling_gptpangu.py",
    "chars": 22409,
    "preview": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimpor"
  },
  {
    "path": "src/resources/models/pangu-350M/tokenization_gptpangu.py",
    "chars": 5037,
    "preview": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import"
  },
  {
    "path": "src/resources/models/pangu-350M/tokenizer_config.json",
    "chars": 324,
    "preview": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/config.json",
    "chars": 1110,
    "preview": "{\n  \"activation\": \"swiglu\",\n  \"apply_residual_connection_post_layernorm\": false,\n  \"architectures\": [\n    \"QWenLMHeadMod"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/configuration_qwen.py",
    "chars": 2329,
    "preview": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the roo"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/generation_config.json",
    "chars": 364,
    "preview": "{\n  \"chat_format\": \"chatml\",\n  \"decay_bound\": 0.0,\n  \"decay_factor\": 1.0,\n  \"eos_token_id\": 151643,\n  \"factual_nucleus_s"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/modeling_qwen.py",
    "chars": 37463,
    "preview": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the roo"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/qwen_generation_utils.py",
    "chars": 14519,
    "preview": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the roo"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/tokenization_qwen.py",
    "chars": 9368,
    "preview": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the roo"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/tokenizer_config.json",
    "chars": 552,
    "preview": "{\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"QWenTokenizer\",\n  \"auto_map\": {\n    \"AutoToke"
  },
  {
    "path": "src/train_dpo.py",
    "chars": 13936,
    "preview": "import sys\n\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/train_reward.py",
    "chars": 11016,
    "preview": "\nimport sys\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/Users/zeyesun/Documents"
  },
  {
    "path": "src/train_rlhf.py",
    "chars": 21757,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/train_rlhf_trlx.py",
    "chars": 16492,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/train_sft.py",
    "chars": 17061,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/train_sft_wo_trainer.py",
    "chars": 18752,
    "preview": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatg"
  },
  {
    "path": "src/utils/__init__.py",
    "chars": 246,
    "preview": "\nfrom src.utils.config import *\nfrom src.utils.logger import logger\nfrom src.utils.nlp_utils import *\nfrom src.utils.fil"
  },
  {
    "path": "src/utils/config.py",
    "chars": 15501,
    "preview": "\nimport os\nimport yaml\n\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict,"
  },
  {
    "path": "src/utils/file_utils.py",
    "chars": 3122,
    "preview": "\nimport math\nimport os\nimport random\nimport numpy as np\nimport torch\n\nfrom numbers import Number\nfrom pynvml import *\n\nf"
  },
  {
    "path": "src/utils/loading.py",
    "chars": 14397,
    "preview": "import types\nimport torch\nimport glob\nfrom typing import Callable, Dict, Tuple\n\n# Register load pipelines via module imp"
  },
  {
    "path": "src/utils/logger.py",
    "chars": 224,
    "preview": "\nimport os\nimport logging.config\n\nfrom src.utils.config import RESOURCE_PATH\n\nlogging.config.fileConfig(os.path.join(RES"
  },
  {
    "path": "src/utils/method_configs.py",
    "chars": 10813,
    "preview": "import sys\nimport torch\nimport torch.nn.functional as F\n\nfrom functools import reduce\nfrom dataclasses import dataclass\n"
  },
  {
    "path": "src/utils/modeling_utils.py",
    "chars": 35339,
    "preview": "\nimport functools\nfrom typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional\n\nimport os\nimport subproces"
  },
  {
    "path": "src/utils/nlp_utils.py",
    "chars": 175,
    "preview": "\nimport re\n\nCLEAN_TEXT_PATTERN = re.compile(r\"[\\r\\n]\")\n\n\ndef clean_text(text):\n    if not isinstance(text, str):\n       "
  }
]

About this extraction

This page contains the full source code of the sunzeyeah/RLHF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 135 files (1.4 MB), approximately 350.5k tokens, and a symbol index with 1225 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!