[
  {
    "path": "README.md",
    "content": "## Features\n\n提供3大功能：\n- LLM模型预训练：支持常见模型的预训练，包括：decoder结构（LLaMA、GPT）、encoder结构（GLM）\n- LLM模型评测：参考GPT类模型，基于ZeroShot和FewShot实现\n- ChatGPT模型训练pipeline：根据[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325) ，实现3大流程: SFT、Reward Model和RLHF\n  - 支持RLHF阶段 (1) 联合优化reward和policy (2) 单独优化policy，冻结reward \n  - 支持[DPO](https://arxiv.org/abs/2305.18290)作为Reward+RLHF的替代方案，可显著降低显存占用，同时实现RL的效果\n\n## Setup\n\n### 1. Install deepspeed\n```bash\ngit clone https://github.com/microsoft/DeepSpeed.git\ncd deepspeed\nrm -rf build\nTORCH_CUDA_ARCH_LIST=\"7.0\" DS_BUILD_OPS=1 pip install -e . --global-option=\"build_ext\" --global-option=\"-j8\" --no-cache -v --disable-pip-version-check 2>&1 | tee build.log\n```\n如果想创建binary wheel，方便在其他机器上安装，可使用如下命令，会在```dist```目录生成类似可安装文件```deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl```\n```bash\ngit clone https://github.com/microsoft/DeepSpeed.git\ncd deepspeed\nrm -rf build\nTORCH_CUDA_ARCH_LIST=\"7.0\" DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel 2>&1 | tee build.log\n```\n**PS**：需要根据下图，调整```TORCH_CUDA_ARCH_LIST=\"7.0\"```为自己对应的NVIDIA GPU架构\n![image info](./images/torch_cuda_list.png \"torch_cuda_list\")\n\n或运行```torch.cuda.get_device_capability()```获取自己GPU的架构\n\n### 2. Install jieba\n在使用Pangu类模型的时候，其special_token格式为```<sep>```、```<pad>```等，而[tokenization_gptpangu.py](src/resources/models/pangu-350M/tokenization_gptpangu.py)中```tokenize()```函数会使用```jieba```进行分词。但直接```pip install jieba```，默认会将```<```和```>```直接切分开，使用```jieba.add_word(\"<sep>\")```也没有作用，因为```jieba```直接hardcode了会自动切分的token，其中就包括了```<```和```>```。 \n\n因此需要执行：\n```bash\ngit clone https://github.com/fxsjy/jieba.git\ncd jieba\n```\n将代码clone到本地，修改```jieba/__init__.py```中```re_han_default```的取值，具体改动如下：\n\n- 改动前：\n```python\nre_han_default = re.compile(\"([\\u4E00-\\u9FD5a-zA-Z0-9+#&\\._%\\-]+)\", re.U)\n```\n\n- 改动后：\n```python\nre_han_default = re.compile(\"([\\u4E00-\\u9FD5a-zA-Z0-9+#&\\._%\\-<>]+)\", re.U)\n```\n\n修改完成后使用```pip install .```进行本地编译安装，替换原有```jieba```。安装完成后，在代码中加入```jieba.add_word(\"<sep>\")```（该代码已加入[tokenization_gptpangu.py](src/resources/models/pangu-350M/tokenization_gptpangu.py)），即可解决将```<sep>```一类的special token切分为多个id的情况\n\n\n### 3. Install apex (Optional)\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\npip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --no-cache -v --disable-pip-version-check .  2>&1 | tee build.log\n```\n如果想创建binary wheel，方便在其他机器上安装，可使用如下命令，会在```dist```目录生成类似可安装文件```apex-0.0.1+7150e20-cp38-cp38-linux_x86_64.whl```\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\npython setup.py --cpp_ext --cuda_ext bdist_wheel 2>&1 | tee build.log\n```\n\n\n## Data & Model Download\n\n### 1. 
### 3. Install apex (Optional)\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\npip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --no-cache -v --disable-pip-version-check .  2>&1 | tee build.log\n```\nTo build a binary wheel that can be installed on other machines, use the following commands instead; they generate an installable file such as ```apex-0.0.1+7150e20-cp38-cp38-linux_x86_64.whl``` in the ```dist``` directory\n```bash\ngit clone https://github.com/NVIDIA/apex\ncd apex\npython setup.py --cpp_ext --cuda_ext bdist_wheel 2>&1 | tee build.log\n```\n\n\n## Data & Model Download\n\n### 1. Pretrained model download\n\n| Model      | size | huggingface link | Baidu Netdisk link  |  Extraction code      | \n| ----------- | ----------- | ----------- |  ----------- |  ----------- |\n| [Pangu-350M](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 659MB | [sunzeyeah/pangu-350M](https://huggingface.co/sunzeyeah/pangu-350M) | [Pangu-350M](https://pan.baidu.com/s/1IzgtW48S2PKyjxPPMe1rAQ) |  c5jj |\n| [Pangu-2.6B](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 9.8GB | [sunzeyeah/pangu-2_6B](https://huggingface.co/sunzeyeah/pangu-2_6B) | [Pangu-2.6B](https://pan.baidu.com/s/1Tzvja4q_LgQOwkWPQ4jShw)    | 2rad |\n| [Pangu-13B](https://openi.pcl.ac.cn/PCL-Platform.Intelligence/PanGu-Alpha)   | 23.6GB | [sunzeyeah/pangu-13B](https://huggingface.co/sunzeyeah/pangu-13B) | [Pangu-13B](https://pan.baidu.com/s/11fWAeYYKqI7pH0UiuJ5jEQ)    | u3dx |\n| [GLM-350M-chinese](https://github.com/THUDM/GLM) | 679MB | [sunzeyeah/glm-350M-chinese](https://huggingface.co/sunzeyeah/glm-350M-chinese) | [GLM-350M-chinese](https://pan.baidu.com/s/11Lef-E7Tsz5OGOueCpiqaA) | ii8e |\n| [GLM-10B-chinese](https://github.com/THUDM/GLM)   | 18.4G | [sunzeyeah/glm-10B-chinese](https://huggingface.co/sunzeyeah/glm-10B-chinese) | [GLM-10B-chinese](https://pan.baidu.com/s/1GuOefx42n_GzFfwnjoBltw) | fynj  |\n| [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)   | 25.6G | [sunzeyeah/chatglm-6B](https://huggingface.co/sunzeyeah/chatglm-6B) | [ChatGLM-6B](https://pan.baidu.com/s/1OlpkMeQD6-LEpNFWx5E-mg) | uq1k |\n\n**PS**: For the pretrained model downloads provided by this repo,\n- for pytorch_model\*.bin\n  - if the original release already includes them, they are kept unchanged\n  - otherwise, they are converted to pytorch_model\*.bin from the checkpoint provided by the original release\n- the remaining files may differ from the originals, including: modeling_\*.py, tokenization_\*.py, configuration_\*.py, config.json and tokenizer.config\n\n### 2. Data download\n\n| Dataset      | size | huggingface link | Baidu Netdisk link  |  Extraction code      | \n| ----------- | ----------- | ----------- |  ----------- |  ----------- |\n| [CLUE Benchmark](https://github.com/CLUEbenchmark/CLUE)   | 500MB | | [CLUE Benchmark](https://pan.baidu.com/s/15F60nRbBd6d6UvyXdwbXQA) |  m6gt |\n| SFT & Reward Data  | 5GB | [sunzeyeah/chinese_chatgpt_corpus](https://huggingface.co/datasets/sunzeyeah/chinese_chatgpt_corpus) | [SFT & Reward Data](https://pan.baidu.com/s/1sl8PB-Dlt1xLIYczMODyRg) | ecyc |\n| [Encyclopedia QA](https://github.com/brightmart/nlp_chinese_corpus)  | 652MB | | [baike_qa_2019](https://pan.baidu.com/s/1N6I-fvx6FLHwuxJuDLLA8g) | 7jad |\n| [Zhidao QA](https://github.com/SophonPlus/ChineseNlpCorpus) | 847MB | | [zhidao](https://pan.baidu.com/s/1sjR3vABiMbdV1HkUt6kCKQ) | neds |\n| [Couplets](https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz)  | 221MB | | [couplets](https://pan.baidu.com/s/1oo6oaephPXpyU-sDd_37qg) | 54ey |\n| [Classical Chinese](https://github.com/NiuTrans/Classical-Modern)  | 125MB | | [Classical & Modern](https://pan.baidu.com/s/1ZyGkYOgzT9ZEVnpSpB4kQg) | a4cr |\n| [Classical poetry](https://github.com/chinese-poetry/chinese-poetry)  | 87MB | | [chinese poetry](https://pan.baidu.com/s/13uvkA96PdKpKB7ZQ2GkXZQ) | 5zzj |\n| Weibo news comments  | 522MB | | [weibo summary comments](https://pan.baidu.com/s/1h45O0q6gQl3LbH-NjzuRlw) | w0g1 |\n\n**PS**: SFT & Reward Data is constructed from the encyclopedia QA, Zhidao QA, couplet, classical Chinese, classical poetry, and Weibo news comment data above, and can be used directly for SFT and Reward training. See [data_prepare.py](./src/data_prepare.py) for details\n\n\n## Usage\n\n### 1. LLM pretraining\nIncremental pretraining of open-source LLMs, implemented with deepspeed. Two model architectures are currently supported:\n- decoder: LLaMA, Baichuan, Pangu\n- encoder: GLM, ChatGLM\n\n```bash\ncd examples\nbash pretrain.sh\n```\n\n
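[pretrain.sh](examples/pretrain.sh) passes ```--concat_samples```. A minimal sketch of what such sample packing typically means (an illustration under assumptions, not the repo's exact implementation; ```block_size=2048``` mirrors the script's ```--max_length```):\n```python\ndef pack_samples(tokenized_samples, eos_id, block_size=2048):\n    \"\"\"Concatenate tokenized samples, separated by EOS, then cut fixed-size blocks.\"\"\"\n    stream = []\n    for ids in tokenized_samples:\n        stream.extend(ids + [eos_id])\n    # drop the tail remainder so that every block is exactly block_size tokens\n    usable = (len(stream) // block_size) * block_size\n    return [stream[i:i + block_size] for i in range(0, usable, block_size)]\n```\n\n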
### 2. LLM evaluation\nZeroShot, OneShot, or FewShot evaluation of open-source Chinese LLMs. See [eval_pretrain.py](./src/eval_pretrain.py) and [data.py](src/data/data.py) for details.\n\nCurrently supported evaluation tasks:\n- [C-Eval](https://github.com/SJTU-LIT/ceval)\n- [MMLU](https://github.com/hendrycks/test)\n- [CLUEBenchmark](https://github.com/CLUEbenchmark/CLUE): evaluation method and prompt templates follow the [Pangu-alpha paper](https://arxiv.org/abs/2104.12369)\n\nCurrently supported open-source models:\n- LLaMA and its derivatives\n- ChatGLM (1 and 2)\n- Baichuan\n- Qwen\n- Pangu\n- GLM\n\n\n```bash\ncd examples\nbash eval_pretrain.sh\n```\n\n
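Most CLUE tasks in the Results section are scored with the PPL method: each candidate answer is appended to the prompt, and the candidate with the lowest LM loss wins. A minimal sketch (prompt and labels here are hypothetical; the real templates follow the Pangu-alpha paper and live in [data.py](src/data/data.py), and a faithful version would restrict the loss to the candidate tokens only):\n```python\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\npath = \"sunzeyeah/pangu-350M\"  # any of the evaluated models\ntokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)\nmodel = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()\n\ndef candidate_loss(prompt, candidate):\n    # average LM loss of prompt+candidate; lower means the model finds it more likely\n    ids = tokenizer(prompt + candidate, return_tensors=\"pt\").input_ids\n    with torch.no_grad():\n        return model(input_ids=ids, labels=ids).loss.item()\n\nprompt = \"新闻:国足0比3负于日本队。该新闻的类别是\"\nprint(min([\"体育\", \"财经\", \"科技\"], key=lambda c: candidate_loss(prompt, c)))\n```\n\n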
### 3. SFT\nSFT training of an open-source LLM on the SFT & Reward data\n```bash\ncd examples\nbash train_sft.sh\n```\n### 4. Reward Model\nReward model training from the SFT model on the SFT & Reward data\n```bash\ncd examples\nbash train_reward.sh\n```\n\n### 5. RLHF\nFurther updates the SFT model with the PPO algorithm and the Reward Model. Implemented on top of the open-source framework [DeepSpeedChat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat)\n```bash\ncd examples\nbash train_rlhf.sh\n```\n\n### 6. DPO\nUses the DPO algorithm to replace the Reward+RLHF pipeline: no reward model needs to be trained, the effect of RL training is retained, and GPU memory usage drops significantly. Implemented on top of the open-source framework [trl](https://github.com/huggingface/trl)\n```bash\ncd examples\nbash train_dpo.sh\n```\n\n
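For reference, the objective optimized in this stage, as defined in the [DPO paper](https://arxiv.org/abs/2305.18290) (policy $\pi_\theta$, frozen reference policy $\pi_{\mathrm{ref}}$, preferred/rejected responses $y_w$/$y_l$):\n\n$$\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})=-\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w|x)}{\pi_{\mathrm{ref}}(y_w|x)}-\beta\log\frac{\pi_\theta(y_l|x)}{\pi_{\mathrm{ref}}(y_l|x)}\right)\right]$$\n\nSince $\pi_{\mathrm{ref}}$ is only queried for log-probabilities, it carries no optimizer state, which is where the memory saving over the PPO pipeline comes from.\n\n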
## Results\n\n### 1. LLM evaluation\n\n<details>\n<summary><b>C-Eval 5-shot test set (test) results</b></summary>\n<table>\n    <tr>  <td>Model</td>  <td>Avg</td>  <td>Avg(Hard)</td> <td>STEM</td> <td>Social Science</td> <td>Humanities</td> <td>Other</td> </tr>\n    <tr>  <td>Baichuan2-13B-Chat</td>  <td style=\"color:red\"><b>56.30</b></td>  <td>34.20</td>  <td style=\"color:red\"><b>48.20</b></td>  <td style=\"color:red\"><b>70.00</b></td>  <td style=\"color:red\"><b>60.50</b></td>  <td>54.20</td> </tr>\n    <tr>  <td>xverse-13B</td>  <td>55.30</td>  <td>32.50</td>  <td>45.90</td>  <td>66.70</td>  <td>59.50</td>  <td style=\"color:red\"><b>57.60</b></td> </tr>\n    <tr>  <td>Qwen-7B-Chat</td>  <td>54.70</td>  <td>35.40</td>  <td>47.90</td>  <td>68.30</td>  <td>58.70</td>  <td>50.00</td> </tr>\n    <tr>  <td>Baichuan-13B-Base</td>  <td>53.70</td>  <td style=\"color:red\"><b>35.60</b></td>  <td>46.80</td>  <td>65.80</td>  <td>58.00</td>  <td>50.80</td> </tr>\n    <tr>  <td>Baichuan2-7B-Chat</td>  <td>52.50</td>  <td>33.80</td>  <td>45.70</td>  <td>64.20</td>  <td>56.60</td>  <td>50.20</td> </tr>\n    <tr>  <td>ChatGLM2-6B</td>  <td>51.20</td>  <td>33.40</td>  <td>46.90</td>  <td>63.00</td>  <td>51.60</td>  <td>47.70</td> </tr>\n    <tr>  <td>Baichuan-13B-Chat</td>  <td>47.90</td>  <td>31.50</td>  <td>41.40</td>  <td>56.80</td>  <td>53.00</td>  <td>46.50</td> </tr>\n    <tr>  <td>Baichuan-7B</td>  <td>44.20</td>  <td>31.70</td>  <td>39.20</td>  <td>53.30</td>  <td>47.30</td>  <td>41.90</td> </tr>\n    <tr>  <td>Ziya-LLaMA-13B-v1.1</td>  <td>40.10</td>  <td>30.30</td>  <td>35.80</td>  <td>47.30</td>  <td>42.80</td>  <td>38.50</td> </tr>\n    <tr>  <td>ChatGLM1.1-6B</td>  <td>38.10</td>  <td>28.60</td>  <td>33.60</td>  <td>46.70</td>  <td>40.90</td>  <td>35.70</td> </tr>\n    <tr>  <td>AtomGPT-13B-56k</td>  <td>37.60</td>  <td>25.30</td>  <td>32.00</td>  <td>44.70</td>  <td>42.80</td>  <td>36.10</td> </tr>\n    <tr>  <td>LLaMA2-13B-chat</td>  <td>37.10</td>  <td>29.30</td>  <td>34.60</td>  <td>43.60</td>  <td>35.90</td>  <td>37.00</td> </tr>\n    <tr>  <td>ChatGLM-6B</td>  <td>36.30</td>  <td>27.20</td>  <td>32.90</td>  <td>42.80</td>  <td>38.10</td>  <td>34.90</td> </tr>\n    <tr>  <td>LLaMA-30B</td>  <td>35.90</td>  <td>29.90</td>  <td>34.40</td>  <td>42.40</td>  <td>33.30</td>  <td>35.60</td> </tr>\n    <tr>  <td>LLaMA2-7B-chat</td>  <td>33.50</td>  <td>27.30</td>  <td>31.60</td>  <td>38.10</td>  <td>33.80</td>  <td>32.70</td> </tr>\n    <tr>  <td>Ziya-LLaMA-13B-Pretrain-v1</td>  <td>31.10</td>  <td>22.20</td>  <td>27.40</td>  <td>36.50</td>  <td>33.80</td>  <td>30.40</td> </tr>\n    <tr>  <td>LLaMA-13B</td>  <td>29.80</td>  <td>24.20</td>  <td>28.40</td>  <td>33.70</td>  <td>29.60</td>  <td>29.00</td> </tr>\n    <tr>  <td>LLaMA-7B</td>  <td>26.80</td>  <td>26.70</td>  <td>26.20</td>  <td>27.60</td>  <td>25.70</td>  <td>28.10</td> </tr>\n</table>\n</details>\n\n<details>\n<summary><b>MMLU 5-shot test set (test) results</b></summary>\n<table>\n    <tr>  <td>Model</td>  <td>Avg</td>  <td>STEM</td> <td>Social Science</td> <td>Humanities</td> <td>Other</td> </tr>\n    <tr>  <td>Baichuan2-13B-Chat</td>  <td style=\"color:red\"><b>56.90</b></td>  <td style=\"color:red\"><b>47.28</b></td>  <td style=\"color:red\"><b>66.23</b></td>  <td>52.90</td>  <td style=\"color:red\"><b>63.50</b></td> </tr>\n    <tr>  <td>LLaMA-30B</td>  <td>56.33</td>  <td>44.68</td>  <td>65.64</td>  <td style=\"color:red\"><b>54.60</b></td>  <td>61.57</td> </tr>\n    <tr>  <td>xverse-13B</td>  <td>55.24</td>  <td>45.60</td>  <td>64.51</td>  <td>50.32</td>  <td>63.27</td> </tr>\n    <tr>  <td>Qwen-7B-Chat</td>  <td>54.13</td>  <td>41.76</td>  <td>63.43</td>  <td>50.81</td>  <td>62.50</td> </tr>\n    <tr>  <td>LLaMA2-13B-chat</td>  <td>53.98</td>  <td>44.52</td>  <td>63.40</td>  <td>49.37</td>  <td>61.21</td> </tr>\n    <tr>  <td>Baichuan-13B-Base</td>  <td>53.46</td>  <td>43.86</td>  <td>63.14</td>  <td>49.73</td>  <td>59.28</td> </tr>\n    <tr>  <td>Baichuan2-7B-Chat</td>  <td>53.11</td>  <td>43.51</td>  <td>62.26</td>  <td>49.58</td>  <td>59.12</td> </tr>\n    <tr>  <td>Baichuan-13B-Chat</td>  <td>51.12</td>  <td>41.61</td>  <td>59.11</td>  <td>47.52</td>  <td>58.31</td> </tr>\n    <tr>  <td>Ziya-LLaMA-13B-v1.1</td>  <td>51.06</td>  <td>41.89</td>  <td>57.71</td>  <td>49.22</td>  <td>56.54</td> </tr>\n    <tr>  <td>LLaMA2-7B-chat</td>  <td>48.10</td>  <td>39.64</td>  <td>56.28</td>  <td>43.61</td>  <td>55.39</td> </tr>\n    <tr>  <td>LLaMA-13B</td>  <td>46.51</td>  <td>37.23</td>  <td>52.71</td>  <td>44.35</td>  <td>53.04</td> </tr>\n    <tr>  <td>ChatGLM2-6B</td>  <td>45.83</td>  <td>38.75</td>  <td>52.06</td>  <td>43.20</td>  <td>50.82</td> </tr>\n    <tr>  <td>AtomGPT-13B-56k</td>  <td>42.75</td>  <td>36.02</td>  <td>49.04</td>  <td>38.80</td>  <td>49.30</td> </tr>\n    <tr>  <td>Baichuan-7B</td>  <td>41.96</td>  <td>36.63</td>  <td>47.77</td>  <td>37.55</td>  <td>48.31</td> </tr>\n    <tr>  <td>Ziya-LLaMA-13B-Pretrain-v1</td>  <td>41.61</td>  <td>33.61</td>  <td>46.01</td>  <td>39.85</td>  <td>48.05</td> </tr>\n    <tr>  <td>ChatGLM1.1-6B</td>  <td>40.07</td>  <td>32.95</td>  <td>44.55</td>  <td>39.23</td>  <td>44.12</td> </tr>\n    <tr>  <td>ChatGLM-6B</td>  <td>37.87</td>  <td>32.41</td>  <td>43.80</td>  <td>35.60</td>  <td>41.00</td> </tr>\n    <tr>  <td>LLaMA-7B</td>  <td>28.53</td>  <td>26.10</td>  <td>28.76</td>  <td>28.52</td>  <td>24.81</td> </tr>\n</table>\n</details>\n\n<details>\n<summary><b>CLUEBenchmark validation set (dev.json) results</b></summary>\n\n<table>\n    <tr>  <td rowspan=\"2\">Dataset</td>  <td rowspan=\"2\">Method</td>  <td rowspan=\"2\">Metrics</td>  <td rowspan=\"2\">Task Type</td>  <td colspan=\"5\" style=\"text-align:center\">Zero-shot</td>  <td colspan=\"5\" style=\"text-align:center\">Few-shot</td> </tr>\n    <tr>  <td>GLM-350M-chinese</td>  <td>Pangu-350M</td>  <td>Pangu-2.6B</td>  <td>GLM-10B-chinese</td>  <td>Pangu-13B</td>  <td>GLM-350M-chinese</td>  <td>Pangu-350M</td>  <td>Pangu-2.6B</td>  <td>GLM-10B-chinese</td>  <td>Pangu-13B</td> </tr>\n    <tr>  <td>OCNLI</td>  <td>PPL</td>  <td>acc</td>  <td>NLI</td>  <td>0.3074</td>  <td style=\"color:red\"><b>0.3369</b></td>  <td>0.3061</td>  <td>0.3288</td>  <td>0.3301</td>  <td>0.3298</td>  <td>0.3352</td>  <td>0.3216</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>CMNLI</td>  <td>PPL</td>  <td>acc</td>  <td>NLI</td>  <td>0.3279</td>  <td>0.3302</td>  <td>0.3310</td>  <td>0.3338</td>  <td style=\"color:red\"><b>0.3358</b></td>  <td>0.3356</td>  <td>0.3328</td>  <td>0.3300</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>CHID</td>  <td>PPL</td>  <td>acc</td>  <td>Cloze(multi-choices)</td>  <td>0.0734</td>  <td>0.0916</td>  <td>0.0670</td>  <td>0.1016</td>  <td style=\"color:red\"><b>0.1018</b></td>  <td>0.0979</td>  <td>0.1007</td>  <td>0.0996</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>CMRC2018</td>  <td>generation</td>  <td>f1</td>  <td>MRC</td>  <td>0.093</td>  <td>0.0979</td>  <td>0.1007</td>  <td style=\"color:red\"><b>0.1392</b></td>  <td>0.021</td>  <td>0.09345</td>  <td>0.097</td>  <td>0.1007</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>CLUEWSC2020</td>  <td>PPL</td>  <td>acc</td>  <td>WSC</td>  <td>0.4934</td>  <td>0.5328</td>  <td style=\"color:red\"><b>0.5592</b></td>  <td>0.5131</td>  <td>0.4671</td>  <td>0.5526</td>  <td>0.4473</td>  <td>0.4671</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>C3</td>  <td>PPL</td>  <td>acc</td>  <td>Common sense reasoning</td>  <td>0.2360</td>  <td>0.2426</td>  <td>0.2418</td>  <td style=\"color:red\"><b>0.2573</b></td>  <td>0.2567</td>  <td>0.2476</td>  <td>0.2559</td>  <td>0.2515</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>AFQMC</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td style=\"color:red\"><b>0.6306</b></td>  <td>0.4582</td>  <td>0.4914</td>  <td>0.4960</td>  <td>0.5000</td>  <td>0.4872</td>  <td>0.4993</td>  <td>0.5018</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>CSL</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.4943</td>  <td>0.4913</td>  <td>0.4666</td>  <td style=\"color:red\"><b>0.5126</b></td>  <td>0.4996</td>  <td>0.5140</td>  <td>0.5036</td>  <td>0.4973</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>IFLYTEK</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.1292</td>  <td style=\"color:red\"><b>0.3058</b></td>  <td>0.265</td>  <td>0.2620</td>  <td>0.2408</td>  <td>0.2539</td>  <td>0.2535</td>  <td>0.2524</td>  <td></td>  <td></td> </tr>\n    <tr>  <td>TNEWS</td>  <td>PPL</td>  <td>acc</td>  <td>Text classification</td>  <td>0.1582</td>  <td>0.2022</td>  <td>0.2449</td>  <td>0.2489</td>  <td style=\"color:red\"><b>0.2527</b></td>  <td>0.2555</td>  <td>0.2466</td>  <td>0.2494</td>  <td></td>  <td></td> </tr>\n</table>\n</details>\n\n
### 2. SFT\n\nModel training parameters:\n\n| Model | Trainable params | Data size | batch size | sequence length | Hardware | GPU memory | speed | Hours per epoch |\n| --- | --- | --- | --- | :---: | :---: | :---: | --- | --- |\n| GLM-350M-chinese | 355M | 5.4M | 4 | 512 | V100 16G | 13G | 3.7 s/iter | 88h |\n| Pangu-350M | 345M | 5.4M | 48 | 512 | A100 80G | 78G | 1.91 s/iter | 27.5h |\n| Pangu-2.6B | 2.6B | 5.4M |  8 | 512 | A100 80G | 79.4G | 9.61 s/iter | 116h |\n\nSFT model download:\n\n| Pretrained model  |  SFT dataset  | size | huggingface link | Baidu Netdisk link  |  Extraction code      | \n| ----------- | --- | ----------- | ----------- |  ----------- | ----------- |\n| Pangu-350M |  SFT & Reward Data | 1.3GB | [sunzeyeah/pangu-350M-sft](https://huggingface.co/sunzeyeah/pangu-350M-sft) | [Pangu-350M-SFT](https://pan.baidu.com/s/14nF63nAQz38jmplUC0EQ-g) | 0f8j |\n| Pangu-2.6B |  SFT & Reward Data | 9.8GB | [sunzeyeah/pangu-2_6B-sft](https://huggingface.co/sunzeyeah/pangu-2_6B-sft) | [Pangu-2.6B-SFT](https://pan.baidu.com/s/1Q2PKf0MnFsiFlNhU-6WIrg) | r13w |\n\n
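The collapsed examples below were sampled with the parameters shown at the top of that section (```top_k=10, temperature=0.8, do_sample=True```). A minimal sketch of an equivalent call (illustrative only; ```max_new_tokens=100``` mirrors ```max_length_generation```, and the repo's own inference code may differ):\n```python\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\npath = \"sunzeyeah/pangu-2_6B-sft\"\ntokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)\nmodel = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).half().cuda().eval()\n\nprompt = \"上联:霜风几笔雕残画<sep>下联:\"\ninputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\nwith torch.no_grad():\n    ids = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=10, temperature=0.8)\nprint(tokenizer.decode(ids[0], skip_special_tokens=True))\n```\n\n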
<details>\n<summary><b>Pangu-2.6B-SFT generation examples</b></summary>\n\n```\n# generation parameters\nmax_length = 512\nmax_length_generation = 100\ntop_k = 10\ntemperature = 0.8\ndo_sample=True\n```\n```\n[Input] 我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点<sep>回答:\n[Generated] 你爱他就多关心他,少在他面前表现出你的脆弱。这样他才更会爱你。\n\n[Input] 一次次买彩,一次次倾听那比分,一次次兴奋,一次次失落,一次次无归而返!最终还是任8~~~~<sep>回答:\n[Generated] 我也是这样,买彩的时间也不长,但是每次都是这样。。。。。。。\n\n[Input] 现代文:太祖很高兴地说:凭着这样的部下去作战,还担心不能取胜么?<sep>古文:\n[Generated] 太祖悦曰:此自是以去,何忧不胜?\n\n[Input] 现代文:十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺取郢州之南门堡。<sep>古文:\n[Generated] 十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺郢州之南门堡。\n\n[Input] 以孔⁇的风格,写一篇经书,题为中庸<sep>\n[Generated] 我闻昔者夫子之言也,曰:吾道一以贯之。其或行之或不行,吾未知之。是故君子慎其独也。吾未见好德如好色者也。好德如好色者,未之有也。好德如好色者,既已之有也。君子不以言举人,不以人废言。好德如好色者,其如命何!吾\n\n[Input] 以陆游的风格,写一首七言诗,题为晨起<sep>\n[Generated] 秋光满纸清⁇明,睡醒一江万马奔。自笑身如枯木鹤,何妨食饱作闲人。\n\n[Input] 上联:霜风几笔雕残画<sep>下联:\n[Generated] 云海千山叠翠峰\n\n[Input] 新闻内容:据悉,这名14岁的女生在Facebook上发布了戏谑考试的博文,其中使用了胡志明1946年反对法国殖民者的讲话,随后被越南当局停课一年。越南官员称,该学生歪曲了历史事件,侮辱了老师的情感。多数网友虽对越南当局做法不满,但都表示,至少他们还能上Facebook.<sep>摘要:越南女生Facebook上模仿胡志明讲话被停课<sep>评论:\n[Generated] 这女的,还真有勇气,我喜欢你!\n\n[Input] 新闻内容:习近平天津考察时勉励大学生志存高远、脚踏实地,转变择业观念,坚持从实际出发,勇于到基层一线和艰苦地方去,把人生的路一步步走稳走实,善于在平凡岗位上创造不平凡的业绩<sep>摘要:习近平勉励大学生转变择业观念勇于到基层和艰苦地方去<sep>评论:\n[Generated] 现在的社会,很多年轻人都是在父母的安排下,去了不同的单位,在不同的地方工作,在不同的城市生活,这就导致了年轻人的思维方式,价值观都不一样,这也导致了年轻人的成长道路,也有很多的不安全感,这也是很正常的事情,年轻人需要正确的引导,正确的引导,也需要正确的引导!\n\n[Input] 问题:劳动仲裁申请法院强制执行,法院强制执行需要多长时间<sep>回答:\n[Generated] 一般申请法院强制执行需要3-6个月时间,具体需要看法院的执行措施。\n\n[Input] 问题:银行贷款50万,30年!年利率是4.41怎么算一个月还多少钱<sep>回答:\n[Generated] 若申请的是招行贷款,要计算贷款利息或者每月还款金额,需要知道贷款本金,贷款期限,还款方式,贷款年利率。若上述信息都确认到,可以通过我行贷款计算器尝试计算,登录招行官网右下方找到“理财计算器”--“个人贷款计算器”可以计算。\n```\n</details>\n\n\n### 3. Reward Model\n\nModel training parameters:\n\n| Model | Trainable params | Data size | batch size | sequence length | Hardware | GPU memory | speed | Hours per epoch |\n| --- | --- | --- | --- | :---: | :---: | :---: | --- | --- |\n| Pangu-350M | 131M | 12M | 32 | 512 | A100 80G | 72.6G | 1.91 s/iter | 105h |\n| Pangu-2.6B | 815M | 12M |  8 | 512 | A100 80G | 80.7G |  | 423h |\n\n\nReward model download:\n\n| Pretrained model  |  SFT dataset  | size | huggingface link | Baidu Netdisk link  |  Extraction code      | \n| ----------- | --- | ----------- | ----------- |  ----------- | ----------- |\n| Pangu-350M |  SFT & Reward Data | 1.3GB | [sunzeyeah/pangu-350M-reward](https://huggingface.co/sunzeyeah/pangu-350M-reward) | [Pangu-350M-Reward](https://pan.baidu.com/s/1wC3w78t7pVn0Xn5tJHy06A) | 4gju |\n\n
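The reward model is trained on pairs of preferred and rejected responses. In the pairwise ranking formulation of [Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325), the loss for a reward model $r_\theta$ is\n\n$$\mathcal{L}(\theta)=-\mathbb{E}_{(x,y_c,y_r)\sim\mathcal{D}}\left[\log\sigma\left(r_\theta(x,y_c)-r_\theta(x,y_r)\right)\right]$$\n\nwhere $y_c$ is the human-preferred response and $y_r$ the rejected one; the ```eval_accuracy``` metric in [train_reward.sh](examples/train_reward.sh) presumably tracks how often $r_\theta(x,y_c)>r_\theta(x,y_r)$.\n\n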
### 4. RLHF\n\nTo be updated\n\n### 5. DeepSpeed experiments\n\nBenchmarking was performed to verify whether deepspeed's officially claimed training-efficiency gains (speed-up, GPU savings, etc.) hold across different pretrained models\n- Scenario: SFT-stage training\n- Parameters: ```max_sequence_length=512```\n\n
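For orientation, a minimal ZeRO stage-3 configuration naming the knobs that the results table below varies (offload optimizer/param, pin memory, overlap comm, allgather bucket size, stage3 max live parameters); the values are illustrative rather than a recommendation, and [DeepSpeed Configuration](src/resources/config/deepspeed/README.md) discusses tuning:\n```python\nimport torch\nimport deepspeed\n\nds_config = {\n    \"train_micro_batch_size_per_gpu\": 2,\n    \"gradient_accumulation_steps\": 8,\n    \"fp16\": {\"enabled\": True},\n    \"zero_optimization\": {\n        \"stage\": 3,\n        \"offload_optimizer\": {\"device\": \"cpu\", \"pin_memory\": True},\n        \"offload_param\": {\"device\": \"cpu\", \"pin_memory\": True},\n        \"overlap_comm\": False,\n        \"allgather_bucket_size\": 1e5,\n        \"stage3_max_live_parameters\": 1e5,\n    },\n}\n\nmodel = torch.nn.Linear(512, 512)  # stand-in for the real LLM\nengine, optimizer, _, _ = deepspeed.initialize(\n    model=model, model_parameters=model.parameters(), config=ds_config\n)\n```\n\n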
<details>\n<summary><b>DeepSpeed experiment results</b></summary>\n<table>\n   <tr> <td>Model</td> <td>Data</td>  <td>Total time/epoch</td>  <td>Time per sample</td>  <td>RAM usage</td>  <td>GPU memory usage</td>  <td>GPU type and count</td> <td>fp16</td> <td>bf16</td> <td>deepspeed stage</td> <td>offload optimizer</td> <td>pin memory</td> <td>offload param</td> <td>overlap comm</td> <td>allgather bucket size</td> <td>stage3 max live parameters</td> <td>batch size</td> <td>gradient accumulation steps</td> <td>gradient checkpointing</td> <td>model half</td> </tr>\n   <tr> <td rowspan=\"11\">T5-large</td> <td rowspan=\"11\">wmt16-en-ro, 610k samples in total</td> <td>43h</td>  <td>0.5s/it</td>  <td>7.1G</td>  <td>1*14529MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>152h</td>  <td>1.78s/it</td>  <td>38.26G</td>  <td>1*11663MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>250h</td>  <td>2.95s/it</td>  <td>38.74G</td>  <td>1*7255MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>62h</td>  <td>5.8s/it</td>  <td>86.81G</td>  <td>8*7811MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>16</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>16</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>290h</td>  <td>3.48s/it</td>  <td>46.53G</td>  <td>1*6655MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>380h</td>  <td>4.5s/it</td>  <td>43.48G</td>  <td>1*5263MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>215h</td>  <td>4.9s/it</td>  <td>47.31G</td>  <td>2*5019MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>1370h</td>  <td>64s/it</td>  <td>57.55G</td>  <td>4*4701MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>948h</td>  <td>90s/it</td>  <td>72.54G</td>  <td>8*4585MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td rowspan=\"7\">Pangu-2.6B</td> <td rowspan=\"7\">validation set of SFT & Reward Data, 10k samples in total</td> <td>2h</td>  <td>5.76s/it</td>  <td>67.86G</td>  <td>1*15631MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>2.1h</td>  <td>6.15s/it</td>  <td>67.88G</td>  <td>1*15705MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>4.5h</td>  <td>13.3s/it</td>  <td>81.02G</td>  <td>1*15449MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>11.5h</td>  <td>8.2s/it</td>  <td>75.89G</td>  <td>1*15299MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>5.5h</td>  <td>7.8s/it</td>  <td>81.16G</td>  <td>2*14851MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>6.2h</td>  <td>18.3s/it</td>  <td>97.31G</td>  <td>4*14389MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td>6.6h</td>  <td>38s/it</td>  <td>118.82G</td>  <td>8*14335MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>2</td> <td>8</td> <td>false</td> <td>false</td> </tr>\n   <tr> <td rowspan=\"14\">ChatGLM-6B</td> <td rowspan=\"14\">validation set of SFT & Reward Data, 10k samples in total</td> <td>-</td>  <td>-</td>  <td>120.45G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e5</td> <td>-</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>120.48G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e3</td> <td>-</td> <td>1</td> <td>8</td> <td>false</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>153.02G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>false</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>154G</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>21.2h</td>  <td>60s/it</td>  <td>154G</td>  <td>1*10443MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>auto</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>21.5h</td>  <td>60s/it</td>  <td>152.81G</td>  <td>1*10409MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>23.5h</td>  <td>65s/it</td>  <td>153.36G</td>  <td>1*9229MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>14h</td>  <td>80s/it</td>  <td>158.21G</td>  <td>2*8631MB</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>7.8h</td>  <td>90s/it</td>  <td>168.38G</td>  <td>4*6743MB</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>4h</td>  <td>90s/it</td>  <td>189.34G</td>  <td>8*6729MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>1h</td>  <td>100s/it</td>  <td>189.38G</td>  <td>8*10047MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>4</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>50min</td>  <td>40s/it</td>  <td>189.39G</td>  <td>8*14763MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>8</td> <td>2</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>35min</td>  <td>113s/it</td>  <td>189.39G</td>  <td>8*14763MB</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>8</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>189.34G</td>  <td>OOM</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>10</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td rowspan=\"11\">GLM-10B-Chinese</td> <td rowspan=\"11\">validation set of SFT & Reward Data, 10k samples in total</td> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>2e8</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>2e8</td> <td>auto</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e5</td> <td>1e5</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e3</td> <td>1e3</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>2*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>OOM</td>  <td>-</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>false</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>4*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>6*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n   <tr> <td>-</td>  <td>-</td>  <td>OOM</td>  <td>-</td>  <td>8*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e2</td> <td>1e2</td> <td>1</td> <td>8</td> <td>true</td> <td>true</td> </tr>\n</table>\n</details>\n\n**PS**: for an introduction to the deepspeed parameters and tuning experience, see [DeepSpeed Configuration](src/resources/config/deepspeed/README.md)\n\n
### 6. LoRA experiments\n\nBenchmarking was performed to verify the training-efficiency gains from LoRA\n\n- Scenario: SFT-stage training\n- Data: validation set of SFT & Reward Data, 10k samples in total\n- Parameters: ```max_sequence_length=512, lora_alpha=1, lora_train_bias='none'```\n\n
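As a reminder of what the rank knob controls: per the LoRA formulation, the pretrained weight $W_0\in\mathbb{R}^{d_{out}\times d_{in}}$ stays frozen and only a low-rank update is trained, scaled by the ```lora_alpha``` above:\n\n$$h=W_0x+\frac{\alpha}{r}BAx,\qquad A\in\mathbb{R}^{r\times d_{in}},\; B\in\mathbb{R}^{d_{out}\times r}$$\n\nso the trainable-parameter column below shrinks roughly linearly with $r$.\n\n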
<details>\n<summary><b>LoRA experiment results</b></summary>\n<table>\n   <tr> <td>Model</td> <td>LoRA rank</td> <td>Trainable params</td> <td>deepspeed</td> <td>batch size</td> <td>GPU type and count</td> <td>GPU memory usage</td> <td>Time per sample</td> <td>Total time/epoch</td> </tr>\n   <tr> <td rowspan=\"8\">Pangu-2.6B</td>  <td>-</td>  <td>2.6B</td>  <td>-</td>  <td>8</td>  <td>1*A100 80G</td>  <td>1*79421MB</td>  <td>9.66s/it</td>  <td>12.5min</td> </tr>\n   <tr> <td>1000</td>  <td>1.5B</td>  <td>-</td>  <td>8</td>  <td>1*A100 80G</td>  <td>1*76129MB</td>  <td>11.61s/it</td>  <td>15min</td> </tr>\n   <tr> <td>500</td>  <td>758MB</td>  <td>-</td>  <td>12</td>  <td>1*A100 80G</td>  <td>1*77179MB</td>  <td>16.2s/it</td>  <td>14min</td> </tr>\n   <tr> <td>100</td>  <td>151MB</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*81103MB</td>  <td>18.6s/it</td>  <td>12min</td> </tr>\n   <tr> <td>50</td>  <td>75MB</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*80809MB</td>  <td>17.8s/it</td>  <td>11.5min</td> </tr>\n   <tr> <td>10</td>  <td>15MB</td>  <td>-</td>  <td>16</td>  <td>1*A100 80G</td>  <td>1*78735MB</td>  <td>17.6s/it</td>  <td>11.5min</td> </tr>\n   <tr> <td>100</td>  <td>151MB</td>  <td>stage=2, w offloading</td>  <td>24</td>  <td>1*A100 80G</td>  <td>1*76933MB</td>  <td>25.5s/it</td>  <td>11min</td> </tr>\n   <tr> <td>100</td>  <td>151MB</td>  <td>stage=3, w offloading</td>  <td>24</td>  <td>1*A100 80G</td>  <td>1*77259MB</td>  <td>46.5s/it</td>  <td>20min</td> </tr>\n   <tr> <td rowspan=\"3\">ChatGLM-6B</td>  <td>-</td>  <td>6.2B</td>  <td>-</td>  <td>3</td>  <td>1*A100 80G</td>  <td>1*79206MB</td>  <td>6.7s/it</td>  <td>23.5min</td> </tr>\n   <tr> <td>1000</td>  <td>1.9B</td>  <td>-</td>  <td>6</td>  <td>1*A100 80G</td>  <td>1*78840MB</td>  <td>12.8s/it</td>  <td>22.5min</td> </tr>\n   <tr> <td>500</td>  <td>994MB</td>  <td>-</td>  <td>6</td>  <td>1*A100 80G</td>  <td>1*68832MB</td>  <td>12.4s/it</td>  <td>21.5min</td> </tr>\n</table>\n</details>"
  },
  {
    "path": "docs/INSTALL_ON_WINDOWS.md",
    "content": "大模型相关的常见package在***Windows***上的安装方案。在以下配置上已测试成功: ```windows11 + python3.11 + cuda12.3 + cudnn8.9.5.30 + torch2.1.0+cu121 + ninja1.11.1 + cmake3.28.0-rc5 + Visual Studio 2022```\n\n## 0. Prerequisite\n#### a. 确认NVIDIA GPU架构\n可根据下图\n![image info](../images/torch_cuda_list.png \"torch_cuda_list\")\n\n或运行```torch.cuda.get_device_capability()```获取\n\n#### b. 软件安装\n需安装Visual Studio、ninja、cmake和CUDA\n\n- Microsoft Visual Studio 2022安装注意事项如下，`单个组件`需勾选6个额外模块 ，[参考链接](https://blog.csdn.net/feking2/article/details/130160845)：\n\n![image info](../images/microsoft_visual_studio_components.png \"microsoft_visual_studio_components\")\n\n- 安装顺序：必须先安装Visual Studio，再安装CUDA\n\n## 1. apex\n目前为止没有成功\n\n## 2. cutlass\n打开Git Bash\n```bash\n# 设定环境变量\nexport CUDACXX=/c/Program\\ Files/NVIDIA\\ GPU\\ Computing\\ Toolkit/CUDA/v12.3/bin/nvcc\n\ngit clone https://github.com/NVIDIA/cutlass.git\nmkdir build && cd build\n\n# 显卡架构的数字根据自己情况修改\ncmake .. -DCUTLASS_NVCC_ARCHS=89 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_LIBRARY=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON\ncmake --build . --config Release -j 8\n```\n在以下commit-id: `b5d8a5d9ccb435268e2215144dca162b0fc6a826`编译成功\n\n## 3. flash attention\n打开Git Bash\n```bash\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention\n# IMPORTANT: do not manually copy cmake-built (or incorrect version?) cutlass code to flash-attention/csrc/cutlass\nMAX_JOBS=8 python setup.py install\n```\n在以下commit-id: `017716451d446e464dde9aca3a3c1ed2209caaa9`安装成功\n\n- benchmark\n\n  - 测试代码: https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py\n\n  - 测试设备：单张4090\n\n  <details>\n    <summary>结果</summary>\n    <table>\n      <tr> <td rowspan=\"2\" style=\"text-align:center\">is causal</td>  <td rowspan=\"2\" style=\"text-align:center\">head dim</td> <td rowspan=\"2\" style=\"text-align:center\">batch size</td> <td rowspan=\"2\" style=\"text-align:center\">seq length</td> <td rowspan=\"2\" style=\"text-align:center\">Model</td> <td colspan=\"3\" style=\"text-align:center\">Speed (TFLOPS/s)</td> </tr>\n      <tr> <td>fwd</td>  <td>bwd</td>  <td>fwd+bwd</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Flash2</td> <td>149.21</td> <td>91.05</td> <td>102.46</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Pytorch</td> <td>20.04</td> <td>22.91</td> <td>22.01</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.cutlass</td> <td>126.21</td> <td>74.03</td> <td>83.95</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.flash</td> <td>161.29</td> <td>99.23</td> <td>111.49</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Flash2</td> <td>167.81</td> <td>124.49</td> <td>134.40</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Pytorch</td> <td>21.43</td> <td>25.68</td> <td>24.30</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.cutlass</td> <td>132.98</td> <td>78.03</td> <td>88.48</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.flash</td> <td>164.71</td> <td>124.29</td> <td>133.66</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Flash2</td> <td>169.75</td> <td>123.30</td> <td>133.76</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Pytorch</td> <td>25.58</td> <td>27.19</td> <td>26.71</td> </tr>\n      <tr> 
## 1. apex\nNot successful so far\n\n## 2. cutlass\nOpen Git Bash\n```bash\n# set environment variable\nexport CUDACXX=/c/Program\\ Files/NVIDIA\\ GPU\\ Computing\\ Toolkit/CUDA/v12.3/bin/nvcc\n\ngit clone https://github.com/NVIDIA/cutlass.git\ncd cutlass\nmkdir build && cd build\n\n# adjust the architecture number to your own GPU\ncmake .. -DCUTLASS_NVCC_ARCHS=89 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_ENABLE_LIBRARY=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON\ncmake --build . --config Release -j 8\n```\nCompiled successfully at commit-id `b5d8a5d9ccb435268e2215144dca162b0fc6a826`\n\n## 3. flash attention\nOpen Git Bash\n```bash\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention\n# IMPORTANT: do not manually copy cmake-built (or incorrect version?) cutlass code to flash-attention/csrc/cutlass\nMAX_JOBS=8 python setup.py install\n```\nInstalled successfully at commit-id `017716451d446e464dde9aca3a3c1ed2209caaa9`\n\n- benchmark\n\n  - benchmark code: https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py\n\n  - test device: a single RTX 4090\n\n  <details>\n    <summary>Results</summary>\n    <table>\n      <tr> <td rowspan=\"2\" style=\"text-align:center\">is causal</td>  <td rowspan=\"2\" style=\"text-align:center\">head dim</td> <td rowspan=\"2\" style=\"text-align:center\">batch size</td> <td rowspan=\"2\" style=\"text-align:center\">seq length</td> <td rowspan=\"2\" style=\"text-align:center\">Model</td> <td colspan=\"3\" style=\"text-align:center\">Speed (TFLOPS/s)</td> </tr>\n      <tr> <td>fwd</td>  <td>bwd</td>  <td>fwd+bwd</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Flash2</td> <td>149.21</td> <td>91.05</td> <td>102.46</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>Pytorch</td> <td>20.04</td> <td>22.91</td> <td>22.01</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.cutlass</td> <td>126.21</td> <td>74.03</td> <td>83.95</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.flash</td> <td>161.29</td> <td>99.23</td> <td>111.49</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Flash2</td> <td>167.81</td> <td>124.49</td> <td>134.40</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>Pytorch</td> <td>21.43</td> <td>25.68</td> <td>24.30</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.cutlass</td> <td>132.98</td> <td>78.03</td> <td>88.48</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.flash</td> <td>164.71</td> <td>124.29</td> <td>133.66</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Flash2</td> <td>169.75</td> <td>123.30</td> <td>133.76</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>Pytorch</td> <td>25.58</td> <td>27.19</td> <td>26.71</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.cutlass</td> <td>136.01</td> <td>77.43</td> <td>88.29</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.flash</td> <td>168.51</td> <td>13.641</td> <td>144.26</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>Flash2</td> <td>141.27</td> <td>142.01</td> <td>141.79</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>Pytorch</td> <td>26.39</td> <td>0.64</td> <td>0.89</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.cutlass</td> <td>134.51</td> <td>67.28</td> <td>78.49</td> </tr>\n      <tr> <td>false</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.flash</td> <td>5.52</td> <td>152.40</td> <td>17.71</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>Flash2</td> <td>2.13</td> <td>43.81</td> <td>6.65</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>Pytorch</td> <td>7.49</td> <td>10.79</td> <td>9.58</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.cutlass</td> <td>100.97</td> <td>56.74</td> <td>64.85</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>32</td> <td>512</td> <td>xformers.flash</td> <td>2.38</td> <td>7.61</td> <td>4.68</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>Flash2</td> <td>3.36</td> <td>89.96</td> <td>10.73</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>Pytorch</td> <td>8.35</td> <td>12.27</td> <td>10.82</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.cutlass</td> <td>115.96</td> <td>58.43</td> <td>68.08</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>16</td> <td>1024</td> <td>xformers.flash</td> <td>3.06</td> <td>90.85</td> <td>9.87</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>Flash2</td> <td>5.24</td> <td>112.91</td> <td>16.44</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>Pytorch</td> <td>8.76</td> <td>13.39</td> <td>11.63</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.cutlass</td> <td>123.37</td> <td>67.61</td> <td>77.63</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>8</td> <td>2048</td> <td>xformers.flash</td> <td>4.37</td> <td>115.97</td> <td>13.99</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>Flash2</td> <td>7.26</td> <td>134.79</td> <td>22.40</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>Pytorch</td> <td>0.22</td> <td>0.35</td> <td>0.30</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.cutlass</td> <td>129.03</td> <td>70.07</td> <td>80.60</td> </tr>\n      <tr> <td>true</td> <td>64</td> <td>4</td> <td>4096</td> <td>xformers.flash</td> <td>6.25</td> <td>103.89</td> <td>19.02</td> </tr>\n    </table>\n  </details>\n\n
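After installation, a minimal smoke test of the core kernel (shapes and dtype here are assumptions chosen to match the benchmark settings above; flash-attention requires fp16/bf16 tensors on CUDA):\n```python\nimport torch\nfrom flash_attn import flash_attn_func\n\n# (batch, seqlen, nheads, headdim), fp16 on GPU\nq, k, v = (torch.randn(8, 2048, 16, 64, device=\"cuda\", dtype=torch.float16) for _ in range(3))\nout = flash_attn_func(q, k, v, causal=True)\nprint(out.shape)  # expected: torch.Size([8, 2048, 16, 64])\n```\n\n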
## 4. xformers\nOpen any terminal, e.g. cmd, powershell, or git bash\n```cmd\npip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121\n```\n\n## 5. deepspeed\nOpen Git Bash\n```bash\ngit clone https://github.com/microsoft/DeepSpeed.git\ncd deepspeed\nrm -rf build\n# adjust the architecture number to your own GPU\nTORCH_CUDA_ARCH_LIST=8.9 DS_BUILD_OPS=0 DS_BUILD_AIO=0 DS_BUILD_CCL_COMM=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_CPU_ADAGRAD=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_TRANSFORMER_INFERENCE=0 DS_BUILD_INFERENCE_CORE_OPS=0 DS_BUILD_CUTLASS_OPS=0 python setup.py build_ext -j8 bdist_wheel 2>&1 | tee build.log\n# when finished, a whl file is generated in the dist directory\npip install dist/deepspeed-0.12.4+a3926bbb-cp311-cp311-win_amd64.whl\n```\nInstalled successfully at commit-id `a3926bbbf6d0025b5c6076a280e6b91ebd08aada`\n\n***PS***: deepspeed relies on the NCCL backend for distributed training, and NCCL does not support Windows; so although deepspeed can be installed, it currently cannot be used for training\n"
  },
  {
    "path": "examples/data_prepare.sh",
    "content": "#!/bin/bash\n\nMODEL=\"pangu-2.6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp/\"\nDATR_DIR=$ROOT/Data/chatgpt/raw\n#MAIN=$ROOT/Code/chatgpt/src/data_prepare.py\nMAIN=$ROOT/Code/RLHF/src/data_prepare.py\nMODEL_PATH=$ROOT/Data/models/$MODEL\n#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/processed\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\npython $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH"
  },
  {
    "path": "examples/eval_pretrain.sh",
    "content": "#!/bin/bash\n\n# C-Eval and MMLU benchamarks\nTASK=\"ceval\"\nMODEL=\"llama-7B\"\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/raw/$TASK\n#MAIN=$ROOT/Code/chatgpt/src/eval_pretrain.py\nMAIN=$ROOT/Code/RLHF/src/eval_pretrain.py\nMODEL_PATH=$ROOT/Data/models/$MODEL\n#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/pretrain/$MODEL\nEVAL_FILENAME=\"val\"\nTRAIN_FILENAME=\"dev\"\nCHECKPOINT=$ROOT/Data/chatgpt/output/pretrain/$MODEL\nSHOTS=5\nMAX_LENGTH=1280\n\ncd $ROOT/Code/RLHF || exit\n#    cd $ROOT/Code/chatgpt || exit\nmkdir -p $OUTPUT_DIR\n\n#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\n#deepspeed --num_gpus 1 $MAIN \\\npython $MAIN \\\n  --device_map \"auto\" \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --task $TASK \\\n  --train_filename $TRAIN_FILENAME \\\n  --eval_filename $EVAL_FILENAME \\\n  --checkpoint $CHECKPOINT \\\n  --max_length $MAX_LENGTH \\\n  --max_few_shot $SHOTS \\\n  --max_length_generation 1 \\\n  > out/eval_pretrain_${TASK}_${MODEL}_${EVAL_FILENAME}_${SHOTS}-shots_${MAX_LENGTH}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n\n## Traditional NLP benchmark Evaluations\n#for TASK in  \"cluewsc2020\" \"afqmc\" \"csl\" \"iflytek\" \"ocnli\" \"cmnli\" \"tnews\" \"c3\" \"cmrc2018\" \"chid\"\n#do\n#  for MODEL in \"pangu-350M\" \"pangu-2.6B\" \"glm-350M-chinese\" \"glm-10B-chinese\" \"pangu-13B\"\n#  do\n#    ROOT=\"/mnt/sfevol775196/sunzeye273\"\n#    #ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#    #ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#    #ROOT=\"/root/autodl-tmp\"\n#    DATR_DIR=$ROOT/Data/chatgpt/raw/$TASK\n##    MAIN=$ROOT/Code/chatgpt/src/eval_pretrain.py\n#    MAIN=$ROOT/Code/RLHF/src/eval_pretrain.py\n#    MODEL_PATH=$ROOT/Data/models/$MODEL\n#    #MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\n#    OUTPUT_DIR=$ROOT/Data/chatgpt/output/pretrain/$MODEL\n#    EVAL_FILENAME=\"dev.json\"\n#    TRAIN_FILENAME=\"train.json\"\n#    case $MODEL in\n#       \"pangu-2.6B\")\n#          BATCH_SIZE=8\n#          ;;\n#       *)\n#         BATCH_SIZE=32\n#         ;;\n#    esac\n#\n#    cd $ROOT/Code/RLHF || exit\n##    cd $ROOT/Code/chatgpt || exit\n#    mkdir -p $OUTPUT_DIR\n#\n#    #CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\n#    #deepspeed --num_gpus 1 $MAIN \\\n#    python $MAIN \\\n#      --device_map auto \\\n#      --data_dir $DATR_DIR \\\n#      --output_dir $OUTPUT_DIR \\\n#      --model_name_or_path $MODEL_PATH \\\n#      --task $TASK \\\n#      --max_length 512 \\\n#      --train_filename $TRAIN_FILENAME \\\n#      --eval_filename $EVAL_FILENAME \\\n#      --eval_batch_size $BATCH_SIZE \\\n#      --top_p 0.8 \\\n#      --temperature 0.8 \\\n#      --num_return_sequences 1 \\\n#      --max_length_generation 100 \\\n#      > out/eval_pretrain_${MODEL}_${TASK}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1\n#  done\n#done"
  },
  {
    "path": "examples/eval_reward.sh",
    "content": "#!/bin/bash\n\nMODEL=\"chatglm-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_reward.py\nMAIN=$ROOT/Code/RLHF/src/train_reward.py\nTOKENIZER_PATH=$ROOT/Data/models/$MODEL\n#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nMODEL_PATH=$ROOT/Data/chatgpt/output/reward/${MODEL}\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/reward/$MODEL\nEVAL_FILENAME=\"dev_data_external_v1.jsonl\"\n#CHECKPOINT=\"${ROOT}/Data/chatgpt/output/reward/${MODEL}/pytorch_modelstar.bin\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#python $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --max_length 512 \\\n  --logging_steps 10 \\\n  --deepspeed_config \"stage-3.json\" \\\n  --do_eval \\\n  --eval_filename $EVAL_FILENAME \\\n  --eval_batch_size 20 \\\n  > out/eval_reward_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/pred_pretrain.sh",
    "content": "#!/bin/bash\n\nMODEL=\"llama-7B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/pretrain.py\nMAIN=$ROOT/Code/RLHF/src/pretrain.py\nMODEL_PATH=$ROOT/Data/models/$MODEL\n#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/pretrain/$MODEL/checkpoint-2000\nCHECKPOINT=$OUTPUT_DIR/pytorch_model.bin\nTEST_FILENAME=\"test_prompts.jsonl\"\nOUTPUT_FILENAME=\"output_${MODEL}.jsonl\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\npython $MAIN \\\n  --device_map \"auto\" \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --checkpoint $CHECKPOINT \\\n  --max_length 1024 \\\n  --max_length_generation 512 \\\n  --bits 16 \\\n  --do_pred \\\n  --test_filename $TEST_FILENAME \\\n  --output_filename $OUTPUT_FILENAME \\\n  > out/pred_pretrain_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/pred_reward.sh",
    "content": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_reward.py\nMAIN=$ROOT/Code/RLHF/src/train_reward.py\nTOKENIZER_PATH=$ROOT/Data/models/$MODEL\n#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nMODEL_PATH=$ROOT/Data/chatgpt/output/sft/${MODEL}\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/reward/$MODEL\nTEST_FILENAME=\"test_data_external_v1.jsonl\"\nOUTPUT_FILENAME=\"result_external_v1.tsv\"\nCHECKPOINT=\"${ROOT}/Data/chatgpt/output/reward/${MODEL}/pytorch_modelstar.bin\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\npython $MAIN \\\n  --local_rank 0 \\\n  --device_map \"auto\" \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --checkpoint $CHECKPOINT \\\n  --max_length 512 \\\n  --logging_steps 50 \\\n  --do_pred \\\n  --test_filename $TEST_FILENAME \\\n  --eval_batch_size 256 \\\n  --output_filename $OUTPUT_FILENAME \\\n  > out/pred_reward_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/pretrain.sh",
    "content": "#!/bin/bash\n\nMODEL=\"llama-7B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/pretrain.py\nMAIN=$ROOT/Code/RLHF/src/pretrain_wo_trainer.py\nMODEL_PATH=$ROOT/Data/models/$MODEL\n#MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/pretrain/$MODEL\nTRAIN_FILENAME=\"pretrain_data_v1.jsonl\"\nEVAL_FILENAME=\"pretrain_eval_data_v1.jsonl\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\n#python $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --max_length 2048 \\\n  --logging_steps 50 \\\n  --save_steps 1000 \\\n  --metric_for_best_model \"eval_rougeLsum\" \\\n  --learning_rate 3e-4 \\\n  --do_train \\\n  --train_filename $TRAIN_FILENAME \\\n  --concat_samples \\\n  --num_epochs 2 \\\n  --train_batch_size 8 \\\n  --gradient_accumulation_steps 8 \\\n  --warmup_steps 100 \\\n  --gradient_checkpointing \\\n  --deepspeed_config \"stage-3-no_trainer.json\" \\\n  --do_eval \\\n  --eval_filename $EVAL_FILENAME \\\n  --eval_batch_size 8 \\\n  --eval_steps 500 \\\n  > out/pretrain_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/train_dpo.sh",
    "content": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\n#ROOT=\"/mnt/sfevol775196/sunzeye273\"\nROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_dpo.py\nMAIN=$ROOT/Code/RLHF/src/train_dpo.py\n#TOKENIZER_PATH=$ROOT/Data/models/$MODEL\nTOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nMODEL_PATH=$ROOT/Data/chatgpt/output/sft/$MODEL\nREFERENCE_MODEL_PATH=$ROOT/Data/chatgpt/output/sft/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/dpo/$MODEL\nTRAIN_FILENAME=\"sft_train_v2.1.jsonl\"\nEVAL_FILENAME=\"sft_eval_v1.1.jsonl\"\nTEST_FILENAME=\"${TRAIN_FILENAME},${EVAL_FILENAME}\"\nOUTPUT_FILENAME=\"dpo_logps_v2.1.bin\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\nif [ -f $OUTPUT_DIR/$OUTPUT_FILENAME ]\nthen\n    echo \"${OUTPUT_DIR}/${OUTPUT_FILENAME} already exists, skipping prediction stage\"\nelse\n    python $MAIN \\\n      --local_rank 0 \\\n      --device_map \"cuda:0\" \\\n      --data_dir $DATR_DIR \\\n      --output_dir $OUTPUT_DIR \\\n      --tokenizer_path $TOKENIZER_PATH \\\n      --model_name_or_path $MODEL_PATH \\\n      --max_length 512 \\\n      --logging_steps 10 \\\n      --eval_batch_size 32 \\\n      --do_pred \\\n      --test_filename $TEST_FILENAME \\\n      --output_filename $OUTPUT_FILENAME \\\n      > out/pred_dpo_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1\nfi\n\n#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\n#python $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --model_name_or_path $MODEL_PATH \\\n  --max_length 512 \\\n  --logging_steps 10 \\\n  --save_steps 100 \\\n  --learning_rate 1e-5 \\\n  --do_train \\\n  --train_filename $TRAIN_FILENAME \\\n  --train_batch_size 8 \\\n  --gradient_accumulation_steps 8 \\\n  --num_epochs 5 \\\n  --gradient_checkpointing \\\n  --deepspeed_config \"stage-3.json\" \\\n  --do_eval \\\n  --eval_filename $EVAL_FILENAME \\\n  --eval_batch_size 32 \\\n  --output_filename $OUTPUT_FILENAME \\\n  > out/train_dpo_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/train_reward.sh",
    "content": "#!/bin/bash\n\nMODEL=\"chatglm-6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_reward.py\nMAIN=$ROOT/Code/RLHF/src/train_reward.py\nTOKENIZER_PATH=$ROOT/Data/models/$MODEL\n#TOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nMODEL_PATH=$ROOT/Data/chatgpt/output/sft/${MODEL}\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/reward/$MODEL\nTRAIN_FILENAME=\"train_data_external_v1.jsonl\"\nEVAL_FILENAME=\"dev_data_external_v1.jsonl\"\n#CHECKPOINT=\"${ROOT}/Data/chatgpt/output/sft/${MODEL}/pytorch_modelstar.bin\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#python $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --model_name_or_path $MODEL_PATH \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --max_length 512 \\\n  --logging_steps 10 \\\n  --save_steps 100 \\\n  --metric_for_best_model \"eval_accuracy\" \\\n  --learning_rate 1e-5 \\\n  --do_train \\\n  --train_filename $TRAIN_FILENAME \\\n  --train_batch_size 24 \\\n  --gradient_accumulation_steps 8 \\\n  --num_epochs 5 \\\n  --gradient_checkpointing \\\n  --deepspeed_config \"stage-3.json\" \\\n  --do_eval \\\n  --eval_filename $EVAL_FILENAME \\\n  --eval_batch_size 24 \\\n  > out/train_reward_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/train_rlhf.sh",
    "content": "#!/bin/bash\n\nREWARD_MODEL=\"pangu-350M\"\nSFT_MODEL=\"pangu-2.6B\"\n\nROOT=\"/mnt/sfevol775196/sunzeye273\"\n#ROOT=\"/mnt/share-pa002-vol682688-prd/sunzeye273\"\n#ROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp/\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_rlhf.py\nMAIN=$ROOT/Code/RLHF/src/train_rlhf.py\n#ACCELERATE_CONFIG=$ROOT/Code/RLHF/src/resources/ppo_model/default_accelerate_config.yaml\nTOKENIZER_PATH=$ROOT/Data/models/$REWARD_MODEL\nACTOR_MODEL_PATH=$ROOT/Data/chatgpt/output/sft/$SFT_MODEL\nCRITIC_MODEL_PATH=$ROOT/Data/models/$REWARD_MODEL\n#CRITIC_MODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$REWARD_MODEL\nCRITIC_CHECKPOINT=$ROOT/Data/chatgpt/output/reward/$REWARD_MODEL/pytorch_model.bin\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/rlhf/$SFT_MODEL\nTRAIN_FILENAME=\"train_data_external_v1.jsonl\"\nPRETRAIN_FILENAME=\"train_data_external_v1.jsonl\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#python $MAIN \\\n#accelerate launch --main_process_port 5007 --config_file $ACCELERATE_CONFIG $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --actor_model_path $ACTOR_MODEL_PATH \\\n  --critic_model_path $CRITIC_MODEL_PATH \\\n  --critic_checkpoint $CRITIC_CHECKPOINT \\\n  --max_length 512 \\\n  --max_gen_length 256 \\\n  --logging_steps 10 \\\n  --do_train \\\n  --train_filename $TRAIN_FILENAME \\\n  --pretrain_filename $PRETRAIN_FILENAME \\\n  --actor_learning_rate 1e-5 \\\n  --critic_learning_rate 1e-5 \\\n  --lr_scheduler_type cosine \\\n  --train_batch_size 4 \\\n  --ppo_train_batch_size 4 \\\n  --gradient_accumulation_steps 16 \\\n  --num_epochs 1 \\\n  --ppo_epochs 1 \\\n  --enable_hybrid_engine \\\n  --actor_zero_stage 3 \\\n  --critic_zero_stage 3 \\\n  --offload_reference_model \\\n  --actor_gradient_checkpointing \\\n  --critic_gradient_checkpointing \\\n  --release_inference_cache \\\n  > out/train_rlhf_${SFT_MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "examples/train_sft.sh",
    "content": "#!/bin/bash\n\nMODEL=\"chatglm2-6B\"\n\n#ROOT=\"/mnt/sfevol775196/sunzeye273\"\nROOT=\"/mnt/pa002-28359-vol543625-private\"\n#ROOT=\"/root/autodl-tmp\"\nDATR_DIR=$ROOT/Data/chatgpt/processed\n#MAIN=$ROOT/Code/chatgpt/src/train_sft.py\nMAIN=$ROOT/Code/RLHF/src/train_sft.py\n#TOKENIZER_PATH=$ROOT/Data/models/$MODEL\nTOKENIZER_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\n#MODEL_PATH=$ROOT/Data/models/$MODEL\nMODEL_PATH=/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint/$MODEL\nOUTPUT_DIR=$ROOT/Data/chatgpt/output/sft/$MODEL\nTRAIN_FILENAME=\"train_data_external_v1.jsonl\"\nEVAL_FILENAME=\"dev_data_external_v1.jsonl\"\n\n#cd $ROOT/Code/chatgpt || exit\ncd $ROOT/Code/RLHF || exit\nmkdir -p $OUTPUT_DIR\n\n#CUDA_VISIBLE_DEVICES=1 deepspeed --master_port 5008 $MAIN \\\n#python $MAIN \\\nCUDA_LAUNCH_BLOCKING=1 deepspeed $MAIN \\\n  --data_dir $DATR_DIR \\\n  --output_dir $OUTPUT_DIR \\\n  --tokenizer_path $TOKENIZER_PATH \\\n  --model_name_or_path $MODEL_PATH \\\n  --max_length 512 \\\n  --logging_steps 10 \\\n  --save_steps 100 \\\n  --metric_for_best_model \"eval_rougeLsum\" \\\n  --learning_rate 1e-5 \\\n  --do_train \\\n  --train_filename $TRAIN_FILENAME \\\n  --concat_samples \\\n  --train_batch_size 8 \\\n  --gradient_accumulation_steps 8 \\\n  --num_epochs 5 \\\n  --gradient_checkpointing \\\n  --deepspeed_config \"stage-3.json\" \\\n  --do_eval \\\n  --eval_filename $EVAL_FILENAME \\\n  --eval_batch_size 32 \\\n  > out/train_sft_${MODEL}_\"`date \"+%Y-%m-%d-%H:%M:%S\"`\".log 2>&1 &\n"
  },
  {
    "path": "notebook/data_prepare.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os, time, re, random, glob, json, jieba, copy\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoModelForCausalLM,\\n\",\n    \"    AutoModelForSeq2SeqLM,\\n\",\n    \"    AutoModelForMultipleChoice,\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    Trainer,\\n\",\n    \"    TrainingArguments,\\n\",\n    \"    default_data_collator,\\n\",\n    \"    TextGenerationPipeline\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"device=\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"from sys import platform\\n\",\n    \"if platform == \\\"linux\\\" or platform == \\\"linux2\\\":\\n\",\n    \"    # linux\\n\",\n    \"    root = \\\"/mnt/sfevol775196/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/pa002-28359-vol543625-private/Data\\\"\\n\",\n    \"#     root = \\\"/root/autodl-tmp/Data\\\"\\n\",\n    \"elif platform == \\\"darwin\\\":\\n\",\n    \"    # OS X\\n\",\n    \"    root = \\\"/Users/zeyesun/Documents/Data\\\"\\n\",\n    \"elif platform == \\\"win32\\\":\\n\",\n    \"    # Windows...\\n\",\n    \"    root = \\\"D:\\\\\\\\Data\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"CLEAN_TEXT_PATTERN = re.compile(r\\\"[\\\\r\\\\n]\\\")\\n\",\n    \"\\n\",\n    \"def clean_text(text):\\n\",\n    \"    return CLEAN_TEXT_PATTERN.sub(\\\"\\\", text)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"pangu-350M\\\"\\n\",\n    \"# model_name = \\\"chatglm-6B\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"#     f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Data Processing\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### weibo_summary_comments_json\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"t = time.time()\\n\",\n    \"fi = os.path.join(root, \\\"raw\\\", \\\"weibo_summary_comments_json.json\\\")\\n\",\n    \"fo = 
os.path.join(root, \"chatgpt\", \"processed\", \"weibo_summary_comments.jsonl\")\n",\n    "ct = 0\n",\n    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",\n    "    with open(fi, \"r\", encoding=\"utf-8\") as r:\n",\n    "        while True:\n",\n    "            line = r.readline()\n",\n    "            if not line:\n",\n    "                break\n",\n    "            \n",\n    "            item = json.loads(line.strip(\"\\n\"))\n",\n    "            article = item['article'].replace(\" \", \"\")\n",\n    "            abstract = item['abstract'].replace(\" \", \"\")\n",\n    "            prompt = f\"新闻内容：{article}{tokenizer.sep_token}摘要：{abstract}{tokenizer.sep_token}评论：\"\n",\n    "            answers = [\n",\n    "                {\n",\n    "                    \"answer\": k.replace(\" \", \"\"), \n",\n    "                    \"score\": int(v)\n",\n    "                } for (k, v) in sorted(item['comments'], key=lambda x: (int(x[1]), len(x[0])), reverse=True)\n",\n    "            ]\n",\n    "            w.write(json.dumps({\"prompt\": prompt, \"answers\": answers}, ensure_ascii=False)+'\\n')\n",\n    "            ct += 1\n",\n    "print(f\"length: {ct}, time taken: {time.time()-t} s\")"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "metadata": {},\n   "source": [\n    "### couplets"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "t1 = time.time()\n",\n    "fi = os.path.join(root, \"raw\", \"couplets.txt\")\n",\n    "fo = os.path.join(root, \"chatgpt\", \"processed\", \"couplets.jsonl\")\n",\n    "l2 = []\n",\n    "nexts = dict()\n",\n    "with open(fi, \"r\", encoding=\"utf-8\") as r:\n",\n    "    while True:\n",\n    "        line = r.readline()\n",\n    "        if not line:\n",\n    "            break\n",\n    "        line = line.strip(\"\\n\")\n",\n    "        idx = len(line) // 2\n",\n    "        prompt = line[:idx]\n",\n    "        answer = line[idx+1:]\n",\n    "        answers = [{\"answer\": answer, \"score\": 1}]\n",\n    "        l2.append({\"prompt\": f\"上联：{prompt}{tokenizer.sep_token}下联：\", \"answers\": answers})\n",\n    "        length = len(answer)\n",\n    "        if length not in nexts:\n",\n    "            nexts[length] = list()\n",\n    "        nexts[length].append(answer)\n",\n    "t2 = time.time()\n",\n    "print(f\"length: {len(l2)}, # different lengths: {len(nexts)}, time taken: {t2-t1} s\")\n",\n    "with open(fo, \"w\", encoding=\"utf-8\") as w:\n",\n    "    for i, l in tqdm(enumerate(l2), desc=\"Processing Couplets\"):\n",\n    "        answer = l['answers'][0]\n",\n    "        answers = l['answers']\n",\n    "        length = len(answer['answer'])\n",\n
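    "        # negative sampling for the pairwise reward data: the true second line keeps score 1,\n",\n    "        # random second lines of the same length score 0, lines of a different length score -1\n",\n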
answers.extend([{\\\"answer\\\": random.choice(nexts[key]), \\\"score\\\": -1} for key in random.sample(keys, 2)])\\n\",\n    \"#         answers = sorted(answers, key=lambda x: x['score'], reverse=True)\\n\",\n    \"        w.write(json.dumps({\\\"prompt\\\": l['prompt'], \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"#         if i % 1000 == 0:\\n\",\n    \"#             print(f\\\"{i} samples processed, time taken: {time.time()-t2} s\\\")\\n\",\n    \"print(f\\\"length: {len(l2)}, time taken: {time.time()-t2} s\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### zhidao\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"t = time.time()\\n\",\n    \"fp = os.path.join(root, \\\"raw\\\", \\\"zhidao\\\", \\\"*.csv\\\")\\n\",\n    \"fo = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"zhidao.jsonl\\\")\\n\",\n    \"ct = 0\\n\",\n    \"with open(fo, \\\"w\\\", encoding=\\\"utf-8\\\") as w:\\n\",\n    \"    for fi in glob.glob(fp):\\n\",\n    \"        ct = 0\\n\",\n    \"        df = pd.read_csv(fi).sort_values(by=[\\\"title\\\", \\\"is_best\\\"], ascending=False)\\n\",\n    \"        prev_title = None\\n\",\n    \"        prev_prompt = None\\n\",\n    \"        for _, val in df.iterrows():\\n\",\n    \"            if isinstance(val['question'], str) and val['question'] != val['title']:\\n\",\n    \"                prompt = f\\\"问题：{val['title']}{tokenizer.sep_token}内容：{val['question']}{tokenizer.sep_token}回答：\\\"\\n\",\n    \"            else:\\n\",\n    \"                prompt = f\\\"问题：{val['title']}{tokenizer.sep_token}回答：\\\"\\n\",\n    \"            if prev_title is not None and prev_title == val['title']:\\n\",\n    \"                answers.append({\\\"answer\\\": val['reply'], \\\"score\\\": val['is_best']})\\n\",\n    \"            else:\\n\",\n    \"                if prev_title is not None:\\n\",\n    \"#                     l3.append({\\\"prompt\\\": prev_prompt, \\\"answers\\\": copy.deepcopy(answers)})\\n\",\n    \"                    w.write(json.dumps({\\\"prompt\\\": prev_prompt, \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"                answers = [{\\\"answer\\\": val['reply'], \\\"score\\\": val['is_best']}]\\n\",\n    \"            prev_prompt = prompt\\n\",\n    \"            prev_title = val['title']\\n\",\n    \"            ct += 1\\n\",\n    \"#         l3.append({\\\"prompt\\\": prev_prompt, \\\"answers\\\": copy.deepcopy(answers)})\\n\",\n    \"        w.write(json.dumps({\\\"prompt\\\": prev_prompt, \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"        print(f\\\"finished processing {os.path.basename(fi)}\\\")\\n\",\n    \"print(f\\\"length: {ct}, time taken: {time.time()-t} s\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### JDData\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from html.parser import HTMLParser\\n\",\n    \"class MyHTMLParser(HTMLParser):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"       #Initializing lists\\n\",\n    \"        self.start_tags = list()\\n\",\n    \"        self.end_tags = list()\\n\",\n    \"        self.start_end_tags = list()\\n\",\n    \"        self.data_list = list()\\n\",\n    \"    #HTML 
    "    # HTML Parser methods\n",\n    "    def handle_starttag(self, startTag, attrs):\n",\n    "        self.start_tags.append(startTag)\n",\n    "    def handle_endtag(self, endTag):\n",\n    "        self.end_tags.append(endTag)\n",\n    "    def handle_startendtag(self, startendTag, attrs):\n",\n    "        self.start_end_tags.append(startendTag)\n",\n    "    def handle_data(self, data):\n",\n    "        self.data_list.append(data)\n",\n    "        \n",\n    "t = time.time()\n",\n    "fp = os.path.join(root, \"raw\", \"JDData\", \"*.data*\")\n",\n    "ct = 0\n",\n    "for fi in glob.glob(fp):\n",\n    "    with open(fi, \"r\", encoding=\"gbk\") as r:\n",\n    "        line = r.readline()\n",\n    "        items = line.strip(\"\\n\").split(\"\\t\")\n",\n    "        parser = MyHTMLParser()\n",\n    "        parser.feed(items[1])\n",\n    "        for tag, d in zip(parser.start_tags, parser.data_list):\n",\n    "            print(f\"{tag}: {d}\")\n",\n    "        ct += 1\n",\n    "    print(f\"finished processing {os.path.basename(fi)}\")\n",\n    "print(f\"length: {ct}, time taken: {time.time()-t} s\")"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "metadata": {},\n   "source": [\n    "### yf_amazon"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "t = time.time()\n",\n    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"products.csv\")\n",\n    "dfp = pd.read_csv(fi)\n",\n    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"ratings.csv\")\n",\n    "dfr = pd.read_csv(fi)\n",\n    "fi = os.path.join(root, \"raw\", \"yf_amazon\", \"categories.csv\")\n",\n    "dfc = pd.read_csv(fi)"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "dfp.columns\n",\n    "# dfp['name'].unique().tolist()"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "dfp['cate_id_1'] = dfp['catIds'].apply(lambda x: x.split(\",\")[0])\n",\n    "for cid1 in dfp['cate_id_1'].unique():\n",\n    "    print(dfc[dfc['catId']==int(cid1)]['category'])"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "metadata": {},\n   "source": [\n    "### dmsc"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "t = time.time()\n",\n    "fi = os.path.join(root, \"raw\", \"dmsc\", \"movies.csv\")\n",\n    "dfm = pd.read_csv(fi)\n",\n    "print(dfm.shape)\n",\n
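    "# dmsc appears to be the Douban Movie Short Comments dump (movies.csv + ratings.csv)\n",\n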
\"fi = os.path.join(root, \\\"raw\\\", \\\"dmsc\\\", \\\"ratings.csv\\\")\\n\",\n    \"dfr = pd.read_csv(fi)\\n\",\n    \"print(dfr.shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dfr.groupby(\\\"movieId\\\", 'rating').count()['comment']\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Chinese Classical-Modern\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"t1 = time.time()\\n\",\n    \"fp = os.path.join(root, \\\"raw\\\", \\\"Classical-Modern\\\", \\\"bitext\\\", \\\"*\\\")\\n\",\n    \"fo = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"chinese_classical.jsonl\\\")\\n\",\n    \"l3 = []\\n\",\n    \"dicts = dict()\\n\",\n    \"for fi in glob.glob(fp):\\n\",\n    \"    name = os.path.basename(fi)\\n\",\n    \"    dicts[name] = {\\\"古文\\\": [], \\\"现代文\\\": []}\\n\",\n    \"    with open(fi, \\\"r\\\", encoding=\\\"utf-8\\\") as r:\\n\",\n    \"        for i, line in enumerate(r):\\n\",\n    \"            line = line.strip(\\\"\\\\n\\\")\\n\",\n    \"            if line.startswith(\\\"古文\\\"):\\n\",\n    \"                p1 = line[3:]\\n\",\n    \"                dicts[name]['古文'].append(p1)\\n\",\n    \"            elif line.startswith(\\\"现代文\\\"):\\n\",\n    \"                p2 = line[4:]\\n\",\n    \"                dicts[name]['现代文'].append(p2)\\n\",\n    \"            elif p1 is not None and p2 is not None:\\n\",\n    \"                pair = [(\\\"古文\\\", p1), (\\\"现代文\\\", p2)]\\n\",\n    \"                random.shuffle(pair)\\n\",\n    \"                prompt = f\\\"{pair[0][0]}：{pair[0][1]}{tokenizer.sep_token}{pair[1][0]}：\\\"\\n\",\n    \"                answers = [{\\\"answer\\\": pair[1][1], \\\"score\\\": 1}]\\n\",\n    \"                l3.append({\\\"prompt\\\": prompt, \\\"answers\\\": answers, \\\"name\\\": name})\\n\",\n    \"                p1 = None\\n\",\n    \"                p2 = None\\n\",\n    \"t2 = time.time()\\n\",\n    \"print(f\\\"length: {len(l3)}, # different names: {len(dicts)}, time taken: {t2-t1} s\\\")\\n\",\n    \"with open(fo, \\\"w\\\", encoding=\\\"utf-8\\\") as w:\\n\",\n    \"    for i, l in tqdm(enumerate(l3), desc=\\\"Processing Chinese Classical-Modern\\\"):\\n\",\n    \"        name = l['name']\\n\",\n    \"        prompt = l['prompt']\\n\",\n    \"        answer = l['answers'][0]['answer']\\n\",\n    \"        if prompt.startswith(\\\"古文\\\"):\\n\",\n    \"            answer_type = '现代文'\\n\",\n    \"        else:\\n\",\n    \"            answer_type = '古文'\\n\",\n    \"        samples_tmp = set(dicts[name][answer_type])\\n\",\n    \"        samples_tmp.remove(answer)\\n\",\n    \"        answers.extend([{\\\"answer\\\": fa, \\\"score\\\": 0} for fa in random.sample(samples_tmp, 2)])\\n\",\n    \"        keys = set(dicts.keys())\\n\",\n    \"        keys.remove(name)\\n\",\n    \"        answers.extend([{\\\"answer\\\": random.choice(dicts[key][answer_type]), \\\"score\\\": -1} for key in random.sample(keys, 2)])\\n\",\n    \"        w.write(json.dumps({\\\"prompt\\\": prompt, \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"#         if i % 100 == 0:\\n\",\n    \"#             print(f\\\"{i} samples processed, time taken: {time.time()-t2} s\\\")\\n\",\n    \"print(f\\\"length: {i}, time taken: {time.time()-t} s\\\")\"\n   ]\n  },\n  {\n   
\"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Chinese Poetry\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import opencc\\n\",\n    \"converter = opencc.OpenCC('t2s.json')\\n\",\n    \"t1 = time.time()\\n\",\n    \"fp = [\\n\",\n    \"    # 四书五经\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"lunyu\\\", \\\"lunyu.json\\\"),\\n\",\n    \"#     os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"mengxue\\\", \\\"*.json\\\"),\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"sishuwujing\\\", \\\"*.json\\\"),\\n\",\n    \"    # 古体诗\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"caocaoshiji\\\", \\\"caocao.json\\\"),\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"shijing\\\", \\\"shijing.json\\\"),\\n\",\n    \"    # 楚辞\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"chuci\\\", \\\"chuci.json\\\"),\\n\",\n    \"    # 诗\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"shi\\\", \\\"poet*.json\\\"),\\n\",\n    \"    # 词\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"ci\\\", \\\"ci*.json\\\"),\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"nalanxingde\\\", \\\"*.json\\\"),\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"wudai\\\", \\\"huajianji\\\", \\\"*juan.json\\\"),\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"wudai\\\", \\\"nantang\\\", \\\"poetrys.json\\\"),\\n\",\n    \"    # 曲\\n\",\n    \"    os.path.join(root, \\\"raw\\\", \\\"chinese-poetry\\\", \\\"yuanqu\\\", \\\"yuanqu.json\\\"),\\n\",\n    \"]\\n\",\n    \"fs = [each for f in fp for each in glob.glob(f)]\\n\",\n    \"\\n\",\n    \"l5 = []\\n\",\n    \"dicts = dict()\\n\",\n    \"for fi in fs:\\n\",\n    \"    lines = json.load(open(fi, \\\"r\\\", encoding=\\\"utf-8\\\"))\\n\",\n    \"    if isinstance(lines, dict):\\n\",\n    \"        lines = [lines]\\n\",\n    \"    for i, line in enumerate(lines):\\n\",\n    \"        if \\\"lunyu\\\" in fi:\\n\",\n    \"            author = \\\"孔子\\\"\\n\",\n    \"            genre = \\\"经书\\\"\\n\",\n    \"            title = line['chapter']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        elif \\\"daxue\\\" in fi:\\n\",\n    \"            author = \\\"曾子\\\"\\n\",\n    \"            genre = \\\"经书\\\"\\n\",\n    \"            title = \\\"大学\\\"\\n\",\n    \"            contents = converter.convert(\\\"\\\".join(line['paragraphs'])).replace(\\\"「\\\", \\\"“\\\").replace(\\\"」\\\", \\\"”\\\")\\n\",\n    \"        elif \\\"mengzi\\\" in fi:\\n\",\n    \"            author = \\\"孟子\\\"\\n\",\n    \"            genre = \\\"经书\\\"\\n\",\n    \"            title = converter.convert(line['chapter'])\\n\",\n    \"            contents = converter.convert(\\\"\\\".join(line['paragraphs'])).replace(\\\"「\\\", \\\"“\\\").replace(\\\"」\\\", \\\"”\\\")\\n\",\n    \"        elif \\\"zhongyong\\\" in fi:\\n\",\n    \"            author = \\\"孔伋\\\"\\n\",\n    \"            genre = \\\"经书\\\"\\n\",\n    \"            title = \\\"中庸\\\"\\n\",\n    \"            contents = converter.convert(\\\"\\\".join(line['paragraphs'])).replace(\\\"「\\\", \\\"“\\\").replace(\\\"」\\\", \\\"”\\\")\\n\",\n    \"        elif \\\"caocao\\\" in fi:\\n\",\n    \" 
           author = \\\"曹操\\\"\\n\",\n    \"            genre = \\\"古体诗\\\"\\n\",\n    \"            title = line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        elif \\\"shijing\\\" in fi:\\n\",\n    \"            author = \\\"诗经\\\"\\n\",\n    \"            genre = \\\"古体诗\\\"\\n\",\n    \"            title = line['chapter'] + \\\"-\\\" + line['section'] + \\\"-\\\" + line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['content'])\\n\",\n    \"        elif \\\"chuci\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"楚辞\\\"\\n\",\n    \"            title = line['section'] + \\\"-\\\" + line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['content'])\\n\",\n    \"        elif \\\"nalanxingde\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"词\\\"\\n\",\n    \"            title = line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['para'])\\n\",\n    \"        elif \\\"huajianci\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"词\\\"\\n\",\n    \"            title = line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        elif \\\"nantang\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"词\\\"\\n\",\n    \"            title = line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        elif \\\"yuanqu\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"曲\\\"\\n\",\n    \"            title = line['title']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        elif \\\"shi\\\" in fi:\\n\",\n    \"            if len(line['paragraphs']) <= 0:\\n\",\n    \"                continue\\n\",\n    \"            author = converter.convert(line['author'])\\n\",\n    \"            genre = \\\"五言诗\\\" if len(line['paragraphs'][0]) == 12 else \\\"七言诗\\\"\\n\",\n    \"            title = converter.convert(line['title'])\\n\",\n    \"            contents = converter.convert(\\\"\\\".join(line['paragraphs']))\\n\",\n    \"        elif \\\"ci\\\" in fi:\\n\",\n    \"            author = line['author']\\n\",\n    \"            genre = \\\"词\\\"\\n\",\n    \"            title = line['rhythmic']\\n\",\n    \"            contents = \\\"\\\".join(line['paragraphs'])\\n\",\n    \"        if genre not in dicts:\\n\",\n    \"            dicts[genre] = dict()\\n\",\n    \"        if author not in dicts[genre]:\\n\",\n    \"            dicts[genre][author] = dict()\\n\",\n    \"        quantifier = \\\"篇\\\" if genre in [\\\"经书\\\", \\\"楚辞\\\"] else \\\"首\\\"\\n\",\n    \"        prompt = f\\\"以{author}的风格，写一{quantifier}{genre}，题为{title}{tokenizer.sep_token}\\\"\\n\",\n    \"        answers = [{\\\"answer\\\": contents, \\\"score\\\": 1}]\\n\",\n    \"        l5.append({\\\"prompt\\\": prompt, \\\"answers\\\": answers, \\\"genre\\\": genre, \\\"title\\\": title, \\\"author\\\": author})\\n\",\n    \"        dicts[genre][author][title] = contents\\n\",\n    \"        \\n\",\n    \"t2 = time.time()\\n\",\n    \"print(f\\\"length: {len(l5)}, # different lengths: {len(dicts)}, time taken: {t2-t1} s\\\")\\n\",\n    \"fo = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"chinese_poetry.jsonl\\\")\\n\",\n    \"with open(fo, \\\"w\\\", encoding=\\\"utf-8\\\") as 
w:\\n\",\n    \"    for i, l in tqdm(enumerate(l5), desc=\\\"Processing Chinese Poetry\\\"):\\n\",\n    \"        genre = l['genre']\\n\",\n    \"        author = l['author']\\n\",\n    \"        title = l['title']\\n\",\n    \"        prompt = l['prompt']\\n\",\n    \"        answers = l['answers']\\n\",\n    \"        # 同作者其他作品-2\\n\",\n    \"        titles_tmp = set(dicts[genre][author].keys())\\n\",\n    \"        titles_tmp.remove(title)\\n\",\n    \"        if len(titles_tmp) > 0:\\n\",\n    \"            t = random.choice(list(titles_tmp))\\n\",\n    \"            answers.append({\\\"answer\\\": dicts[genre][author][t], \\\"score\\\": 0})\\n\",\n    \"        # 同体裁其他作者其他作品-1\\n\",\n    \"        authors_tmp = set(dicts[genre].keys())\\n\",\n    \"        authors_tmp.remove(author)\\n\",\n    \"        a = random.choice(list(authors_tmp))\\n\",\n    \"        t = random.choice(list(dicts[genre][a].keys()))\\n\",\n    \"        answers.append({\\\"answer\\\": dicts[genre][a][t], \\\"score\\\": -1})\\n\",\n    \"        # 不同体裁作品-0\\n\",\n    \"        genres_tmp = set(dicts.keys())\\n\",\n    \"        genres_tmp.remove(genre)\\n\",\n    \"        g = random.choice(list(genres_tmp))\\n\",\n    \"        a = random.choice(list(dicts[g].keys()))\\n\",\n    \"        t = random.choice(list(dicts[g][a].keys()))\\n\",\n    \"        answers.append({\\\"answer\\\": dicts[g][a][t], \\\"score\\\": -2})\\n\",\n    \"        w.write(json.dumps({\\\"prompt\\\": prompt, \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"print(f\\\"length: {i}, time taken: {time.time()-t2} s\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### baike_qa_2019\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"fs = glob.glob(os.path.join(root, \\\"raw\\\", \\\"baike_qa2019\\\", \\\"baike_qa_*.json\\\"))\\n\",\n    \"fo = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"baike_qa.jsonl\\\")\\n\",\n    \"ct = 0\\n\",\n    \"# items = []\\n\",\n    \"# lens_prompt = []\\n\",\n    \"# lens_label = []\\n\",\n    \"with open(fo, \\\"w\\\", encoding=\\\"utf-8\\\") as w:\\n\",\n    \"    for f in fs:\\n\",\n    \"        with open(f, \\\"r\\\", encoding=\\\"utf-8\\\") as r:\\n\",\n    \"            while True:\\n\",\n    \"                line = r.readline()\\n\",\n    \"                if not line:\\n\",\n    \"                    break\\n\",\n    \"                item = json.loads(line.strip(\\\"\\\\n\\\"))\\n\",\n    \"                question = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])\\n\",\n    \"                prompt = f\\\"{question}{tokenizer.sep_token}回答：\\\"\\n\",\n    \"                answer = clean_text(item['answer'])\\n\",\n    \"                answers = [{\\\"answer\\\": answer, \\\"score\\\": 1}]\\n\",\n    \"#                 items.append(item)\\n\",\n    \"#                 lens_prompt.append(len(prompt))\\n\",\n    \"#                 lens_label.append(len(label))\\n\",\n    \"                w.write(json.dumps({\\\"prompt\\\": prompt, \\\"answers\\\": answers}, ensure_ascii=False)+'\\\\n')\\n\",\n    \"                ct += 1\\n\",\n    \"print(ct)\\n\",\n    \"# print(len(items))\\n\",\n    \"# print(np.percentile(lens_prompt, np.arange(90, 101)))\\n\",\n    \"# print(np.percentile(lens_label, np.arange(90, 101)))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": 
{},\n   \"source\": [\n    \"### rm-static\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"fi = os.path.join(root, \\\"raw\\\", \\\"rm-static\\\", \\\"data\\\", \\\"test-00000-of-00001-bf4c733542e35fcb.parquet\\\")\\n\",\n    \"df = pd.read_parquet(fi)\\n\",\n    \"print(df.shape)\\n\",\n    \"df.head(2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# ChatGLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoModel\\n\",\n    \"model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"texts = [\\n\",\n    \"    \\\"你好\\\",\\n\",\n    \"    \\\"晚上睡不着应该怎么办\\\"\\n\",\n    \"]\\n\",\n    \"history = []\\n\",\n    \"for text in texts:\\n\",\n    \"    response, history = model.chat(tokenizer, text, history=history)\\n\",\n    \"    print(f\\\"问: {text}\\\\n答:{response}\\\\n\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "notebook/reward.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a7d1682d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.insert(0, \\\"/root/autodl-tmp/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/Users/zeyesun/Documents/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"D:\\\\\\\\Code\\\\\\\\RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/sfevol775196/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/pa002-28359-vol543625-private/Code/chatgpt\\\")\\n\",\n    \"\\n\",\n    \"import os, time, re, random, glob, json, jieba, copy\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoModelForCausalLM,\\n\",\n    \"    AutoModelForSeq2SeqLM,\\n\",\n    \"    AutoModelForMultipleChoice,\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    Trainer,\\n\",\n    \"    TrainingArguments,\\n\",\n    \"    default_data_collator,\\n\",\n    \"    TextGenerationPipeline\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"from src.models.reward import RewardModel\\n\",\n    \"\\n\",\n    \"device=\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"from sys import platform\\n\",\n    \"if platform == \\\"linux\\\" or platform == \\\"linux2\\\":\\n\",\n    \"    # linux\\n\",\n    \"    root = \\\"/mnt/sfevol775196/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/pa002-28359-vol543625-private/Data\\\"\\n\",\n    \"#     root = \\\"/root/autodl-tmp/Data\\\"\\n\",\n    \"elif platform == \\\"darwin\\\":\\n\",\n    \"    # OS X\\n\",\n    \"    root = \\\"/Users/zeyesun/Documents/Data\\\"\\n\",\n    \"elif platform == \\\"win32\\\":\\n\",\n    \"    # Windows...\\n\",\n    \"    root = \\\"D:\\\\\\\\Data\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a67e7e4e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoConfig\\n\",\n    \"model_name = \\\"pangu-small\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\\n\",\n    \"f = os.path.join(model_name_or_path, \\\"pytorch_model.bin\\\")\\n\",\n    \"torch.save(model.state_dict(), f)\\n\",\n    \"\\n\",\n    \"model.config.lora_rank = 0\\n\",\n    \"reward_model = RewardModel(model.config, model.transformer, tokenizer)\\n\",\n    \"reward_checkpoint = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", \\\"reward\\\", model_name, \\\"pytorch_model.bin\\\")\\n\",\n    \"torch.save(reward_model.state_dict(), reward_checkpoint)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"078a3631\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# model_name = \\\"pangu-350M\\\"\\n\",\n    \"model_name = \\\"pangu-small\\\"\\n\",\n    \"# model_name = \\\"chatglm-6B\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, 
\\\"models\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"#     f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"417b6ada\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lora_rank = 0\\n\",\n    \"lora_alpha = 1\\n\",\n    \"lora_train_bias = \\\"none\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"753d0550\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if \\\"pangu\\\" in model_name_or_path:\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"    model.resize_token_embeddings(tokenizer.vocab_size)\\n\",\n    \"    model.config.lora_rank = lora_rank\\n\",\n    \"    model.config.lora_alpha = lora_alpha\\n\",\n    \"    model.config.lora_train_bias = lora_train_bias\\n\",\n    \"    # Initialize the reward model from the (supervised) fine-tuned SFT model\\n\",\n    \"    reward_model = RewardModel(model.config, model.transformer, tokenizer)\\n\",\n    \"elif \\\"chatglm\\\" in model_name_or_path:\\n\",\n    \"    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True).half()\\n\",\n    \"    model.config.lora_rank = lora_rank\\n\",\n    \"    model.config.lora_alpha = lora_alpha\\n\",\n    \"    model.config.lora_train_bias = lora_train_bias\\n\",\n    \"    # Initialize the reward model from the (supervised) fine-tuned SFT model\\n\",\n    \"    reward_model = RewardModel(model.config, model.transformer, tokenizer).half()\\n\",\n    \"    # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\\n\",\n    \"#     layers = reward_model.transformer.layers\\n\",\n    \"elif \\\"glm\\\" in model_name_or_path:\\n\",\n    \"    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"    model.config.lora_rank = lora_rank\\n\",\n    \"    model.config.lora_alpha = lora_alpha\\n\",\n    \"    model.config.lora_train_bias = lora_train_bias\\n\",\n    \"    # Initialize the reward model from the (supervised) fine-tuned SFT model\\n\",\n    \"    reward_model = RewardModel(model.config, model.glm, tokenizer)\\n\",\n    \"    # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\\n\",\n    \"#     layers = reward_model.transformer.transformer.layers\\n\",\n    \"else:\\n\",\n    \"    raise ValueError(f\\\"Unsupported model name: {model_name_or_path}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3b18dc5e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# reward_checkpoint = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", 
\\\"reward\\\", model_name, \\\"checkpoint-200549\\\", \\\"pytorch_model*.bin\\\")\\n\",\n    \"reward_checkpoint = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", \\\"reward\\\", model_name, \\\"pytorch_model*.bin\\\")\\n\",\n    \"checkpoints = glob.glob(reward_checkpoint)\\n\",\n    \"st = dict()\\n\",\n    \"for checkpoint in checkpoints:\\n\",\n    \"    st.update(torch.load(checkpoint, map_location=\\\"cpu\\\"))\\n\",\n    \"reward_model.load_state_dict(st)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"800f59c3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# reward_model = reward_model.half().to(device)\\n\",\n    \"reward_model = reward_model.to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a68f86ad\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reward Model Train\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"58669674\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from src.data.data import SFTDataset, PairwiseDataset\\n\",\n    \"from torch.utils.data import DataLoader, SequentialSampler\\n\",\n    \"class dotdict(dict):\\n\",\n    \"    __getattr__ = dict.get\\n\",\n    \"    __setattr__ = dict.__setitem__\\n\",\n    \"    __delattr__ = dict.__delitem__\\n\",\n    \"\\n\",\n    \"max_length = 128\\n\",\n    \"batch_size = 2\\n\",\n    \"args = dotdict({\\\"model_name_or_path\\\": model_name_or_path, \\\"max_length\\\": max_length})\\n\",\n    \"train_filename = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"dev_data_external_v1.jsonl\\\")\\n\",\n    \"train_dataset = PairwiseDataset(args, train_filename, tokenizer)\\n\",\n    \"sampler = SequentialSampler(train_dataset)\\n\",\n    \"train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"bdab3a67\",\n   \"metadata\": {\n    \"scrolled\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"reward_model.train()\\n\",\n    \"for batch in train_loader:\\n\",\n    \"    print(batch.keys())\\n\",\n    \"    chosen_input_ids = batch['chosen_input_ids'].to(device)\\n\",\n    \"    chosen_attention_mask = batch['chosen_attention_mask'].to(device) if 'chosen_attention_mask' in batch else None\\n\",\n    \"    chosen_position_ids = batch['chosen_position_ids'].to(device) if 'chosen_position_ids' in batch else None\\n\",\n    \"    rejected_input_ids = batch['rejected_input_ids'].to(device)\\n\",\n    \"    rejected_attention_mask = batch['rejected_attention_mask'].to(device) if 'rejected_attention_mask' in batch else None\\n\",\n    \"    rejected_position_ids = batch['rejected_position_ids'].to(device) if 'rejected_position_ids' in batch else None\\n\",\n    \"    \\n\",\n    \"    chosen_values, chosen_reward = reward_model.reward(chosen_input_ids, chosen_attention_mask, chosen_position_ids)\\n\",\n    \"    reject_values, reject_reward = reward_model.reward(rejected_input_ids, rejected_attention_mask, rejected_position_ids)\\n\",\n    \"    break\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"44da49d2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"chosen_end_scores = []\\n\",\n    \"rejected_end_scores = []\\n\",\n    \"bs = chosen_input_ids.shape[0]\\n\",\n    \"if len(chosen_input_ids.shape) == 3:\\n\",\n    \"    chosen_input_ids = 
chosen_input_ids.squeeze(1)\n",\n    "if len(rejected_input_ids.shape) == 3:\n",\n    "    rejected_input_ids = rejected_input_ids.squeeze(1)\n",\n    "\n",\n    "loss = 0\n",\n    "inference = False\n",\n    "for i in range(bs):\n",\n    "    if torch.all(torch.eq(chosen_input_ids[i], rejected_input_ids[i])).item():\n",\n    "        c_inds = (chosen_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",\n    "        c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n",\n    "        chosen_end_scores.append(chosen_values[i, c_ind - 1])\n",\n    "        inference = True\n",\n    "        continue\n",\n    "    print(f\"inference: {inference}\")\n",\n    "\n",\n    "    # Check if there is any padding otherwise take length of sequence\n",\n    "    c_inds = (chosen_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",\n    "    c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n",\n    "    r_inds = (rejected_input_ids[i] == tokenizer.pad_token_id).nonzero()\n",\n    "    r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected_input_ids.shape[1]\n",\n    "    end_ind = max(c_ind, r_ind)\n",\n    "\n",\n    "    # Retrieve first index where trajectories diverge\n",\n    "    divergence_ind = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()[0]\n",\n    "    assert divergence_ind > 0\n",\n    "\n",\n    "    # Index into the correct rewards\n",\n    "    c_truncated_reward = chosen_values[i][divergence_ind:end_ind]\n",\n    "    r_truncated_reward = reject_values[i][divergence_ind:end_ind]\n",\n    "\n",\n    "    # Append the last rewards to the list of end scores\n",\n    "    chosen_end_scores.append(c_truncated_reward[-1])\n",\n    "    rejected_end_scores.append(r_truncated_reward[-1])\n",\n    "\n",\n    "    # Compute loss\n",\n    "    loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()\n",\n    "\n",\n    "# average the pairwise loss over the batch, outside the loop\n",\n    "loss = loss / bs\n",\n    "chosen_end_scores = torch.stack(chosen_end_scores)"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "52095bab",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "loss"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "id": "6b02e60d",\n   "metadata": {},\n   "source": [\n    "# Reward Model Eval"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "70abf467",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "reward_model.eval()"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "9f8b4d40",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "max_length = 512\n",\n    "tokenizer.padding_side = \"right\"\n",\n    "# tokenizer.padding_side = \"left\"\n",\n    "\n",\n    "prompt = \"现代文:行三十五里,进入登封县境的耿店。\"\n",\n    "prefix = \"古文:\"\n",\n    "pred = \"<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>\"\n",\n    "\n",\n    "encodings_dict = tokenizer(prompt, prefix+pred, max_length=max_length,\n",\n    "                           truncation=\"longest_first\", padding=\"max_length\", return_tensors=\"pt\",\n",\n    "                           return_token_type_ids=False)\n",\n    "print(encodings_dict.keys())\n",\n    "print(encodings_dict['input_ids'].shape)\n",\n
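    "# padded to max_length, so both tensors should have shape (1, max_length)\n",\n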
\"print(encodings_dict['attention_mask'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7d88344b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"input_ids = encodings_dict['input_ids'].to(device)\\n\",\n    \"attention_mask = encodings_dict['attention_mask'].to(device)\\n\",\n    \"res = reward_model(input_ids, attention_mask)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "notebook/rlhf.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e0fa8199\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.insert(0, \\\"/root/autodl-tmp/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/Users/zeyesun/Documents/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"D:\\\\\\\\Code\\\\\\\\RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/sfevol775196/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/pa002-28359-vol543625-private/Code/chatgpt\\\")\\n\",\n    \"\\n\",\n    \"import os, time, re, random, glob, json, jieba, copy\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoModelForCausalLM,\\n\",\n    \"    AutoModelForSeq2SeqLM,\\n\",\n    \"    AutoModelForMultipleChoice,\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    Trainer,\\n\",\n    \"    TrainingArguments,\\n\",\n    \"    default_data_collator,\\n\",\n    \"    TextGenerationPipeline\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"from src.models.reward import RewardModel\\n\",\n    \"\\n\",\n    \"device=\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"from sys import platform\\n\",\n    \"if platform == \\\"linux\\\" or platform == \\\"linux2\\\":\\n\",\n    \"    # linux\\n\",\n    \"    root = \\\"/mnt/sfevol775196/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/pa002-28359-vol543625-private/Data\\\"\\n\",\n    \"#     root = \\\"/root/autodl-tmp/Data\\\"\\n\",\n    \"elif platform == \\\"darwin\\\":\\n\",\n    \"    # OS X\\n\",\n    \"    root = \\\"/Users/zeyesun/Documents/Data\\\"\\n\",\n    \"elif platform == \\\"win32\\\":\\n\",\n    \"    # Windows...\\n\",\n    \"    root = \\\"D:\\\\\\\\Data\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"354bbf38\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# model_name = \\\"pangu-small\\\"\\n\",\n    \"# model_name = \\\"pangu-350M\\\"\\n\",\n    \"model_name = \\\"glm-small\\\"\\n\",\n    \"# model_name = \\\"chatglm-6B\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"    f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b2de5b11\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if \\\"glm\\\" in model_name_or_path:\\n\",\n    \"    model = 
AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",\n    "    if \"chatglm\" in model_name_or_path:\n",\n    "        model = model.half()\n",\n    "else:\n",\n    "    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)\n",\n    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",\n    "# LoRA config, defined here so it exists before first use (0 disables LoRA);\n",\n    "# convert_to_lora_recursively / lora are assumed to be this repo's LoRA helpers\n",\n    "lora_rank = 0\n",\n    "lora_alpha = 1\n",\n    "lora_train_bias = \"none\"\n",\n    "if lora_rank > 0:\n",\n    "    convert_to_lora_recursively(model, lora_rank, lora_alpha)\n",\n    "    lora.mark_only_lora_as_trainable(model, lora_train_bias)\n",\n    "model = model.to(device)\n",\n    "model.eval()\n",\n    "print(model.device)"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "904753d8",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "prompt = \"你好，你是谁？\"\n",\n    "prefix = \"答:\"\n",\n    "label = \"我是ChatGPT\"\n",\n    "max_length = 32\n",\n    "max_gen_length = 16\n",\n    "max_prompt_length = max_length - max_gen_length\n",\n    "lora_rank = 0"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "id": "a1484a82",\n   "metadata": {},\n   "source": [\n    "## Dataset Debug"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "469acb98",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "from src.data.data import SFTDataset\n",\n    "from torch.utils.data import RandomSampler, DataLoader\n",\n    "class dotdict(dict):\n",\n    "    \"\"\"dot.notation access to dictionary attributes\"\"\"\n",\n    "    __getattr__ = dict.get\n",\n    "    __setattr__ = dict.__setitem__\n",\n    "    __delattr__ = dict.__delitem__\n",\n    "\n",\n    "args = {'model_name_or_path': model_name_or_path,\n",\n    "        \"max_length\": 128}\n",\n    "args = dotdict(args)\n",\n    "\n",\n    "train_dataset = SFTDataset(args, \"/Users/zeyesun/Documents/Data/chatgpt/processed/test_data_external_v1.jsonl\", \n",\n    "                              tokenizer)\n",\n    "train_sampler = RandomSampler(train_dataset)\n",\n    "train_dataloader = DataLoader(\n",\n    "        train_dataset,\n",\n    "        sampler=train_sampler,\n",\n    "        batch_size=4)"\n   ]\n  },\n  {\n   "cell_type": "markdown",\n   "id": "14b621de",\n   "metadata": {},\n   "source": [\n    "## Generation Debug"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "1a9d87b5",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "tokenizer.padding_side = \"left\"\n",\n    "inputs = tokenizer(prompt, \n",\n    "#                    label, \n",\n    "#                    tokenizer.sep_token + prefix,\n",\n    "                   max_length=max_prompt_length,\n",\n    "                   padding=\"max_length\",\n",\n    "                   truncation=\"longest_first\", \n",\n    "#                    add_special_tokens=False,\n",\n    "                   return_tensors=\"pt\", \n",\n    "                   return_token_type_ids=False)\n",\n    "print(inputs)"\n   ]\n  },\n  {\n   "cell_type": "code",\n   "execution_count": null,\n   "id": "ce652f55",\n   "metadata": {},\n   "outputs": [],\n   "source": [\n    "batch_size, prompt_length = inputs['input_ids'].shape\n",\n    "\n",\n    "with torch.no_grad():\n",\n
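    "    # do_sample=False means greedy decoding, so the top_p/temperature arguments below have no effect\n",\n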
logger.debug(f\\\"[_generate_sequence] inputs: {inputs}\\\")\\n\",\n    \"    inputs = {k: v.to(device) for k, v in inputs.items()}\\n\",\n    \"    seq = model.generate(**inputs, \\n\",\n    \"                         max_new_tokens=max_gen_length,\\n\",\n    \"                         pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                         do_sample=False,\\n\",\n    \"                         num_return_sequences=1,\\n\",\n    \"                         top_p=0.9,\\n\",\n    \"                         temperature=1.0\\n\",\n    \"                        )\\n\",\n    \"print(f\\\"seq: {seq}\\\")\\n\",\n    \"print(tokenizer.batch_decode(seq))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"018d2c5c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompts = []\\n\",\n    \"for i in range(batch_size):\\n\",\n    \"    prompt_ids = seq[i, :prompt_length]\\n\",\n    \"    prompt_start_index = (prompt_ids != tokenizer.pad_token_id).nonzero()[0].item()\\n\",\n    \"    prompt_ids = seq[i, prompt_start_index:prompt_length]\\n\",\n    \"    answer_ids = seq[i, prompt_length:]\\n\",\n    \"    prompt = tokenizer.decode(prompt_ids, skip_special_tokens=False)\\n\",\n    \"    answer = tokenizer.decode(answer_ids, skip_special_tokens=False)\\n\",\n    \"    prompts.append(prompt + answer)\\n\",\n    \"print(prompts)\\n\",\n    \"outputs = tokenizer(prompts, max_length=max_length,\\n\",\n    \"                              truncation=\\\"longest_first\\\", padding=\\\"max_length\\\",\\n\",\n    \"                              return_tensors=\\\"pt\\\", return_token_type_ids=False)\\n\",\n    \"print(outputs)\\n\",\n    \"print(tokenizer.batch_decode(outputs['input_ids']))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"429917ee\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"outputs['input_ids'].device\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"22ca3132\",\n   \"metadata\": {},\n   \"source\": [\n    \"## ChatGLM attention mask and position ids Debug\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d3308b94\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# chatglm build attention mask\\n\",\n    \"input_ids = inputs['input_ids']\\n\",\n    \"batch_size, seq_length = input_ids.shape\\n\",\n    \"context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\\n\",\n    \"attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\\n\",\n    \"print(attention_mask.shape)\\n\",\n    \"attention_mask.tril_()\\n\",\n    \"for i, context_length in enumerate(context_lengths):\\n\",\n    \"    attention_mask[i, :, :context_length] = 1\\n\",\n    \"print(attention_mask.shape)\\n\",\n    \"attention_mask.unsqueeze_(1)\\n\",\n    \"print(attention_mask.shape)\\n\",\n    \"# attention_mask = (attention_mask < 0.5).bool()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8c4fd4e6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# chatglm bulid position ids\\n\",\n    \"batch_size, seq_length = input_ids.shape\\n\",\n    \"context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\\n\",\n    \"# if self.position_encoding_2d:\\n\",\n    \"position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, 
seq_length)\\n\",\n    \"# if not gmask:\\n\",\n    \"#     for i, context_length in enumerate(context_lengths):\\n\",\n    \"#         position_ids[i, context_length:] = mask_positions[i]\\n\",\n    \"block_position_ids = [torch.cat((\\n\",\n    \"    torch.zeros(context_length, dtype=torch.long, device=device),\\n\",\n    \"    torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\\n\",\n    \")) for context_length in context_lengths]\\n\",\n    \"block_position_ids = torch.stack(block_position_ids, dim=0)\\n\",\n    \"position_ids = torch.stack((position_ids, block_position_ids), dim=1)\\n\",\n    \"# else:\\n\",\n    \"#     position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\\n\",\n    \"#     if not gmask:\\n\",\n    \"#         for i, context_length in enumerate(context_lengths):\\n\",\n    \"#             position_ids[context_length:] = mask_positions[i]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b6b29e22\",\n   \"metadata\": {},\n   \"source\": [\n    \"## GLM attention mask and position ids Debug\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b12dda63\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\\n\",\n    \"prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"label_length = len(tokenizer.tokenize(label)) + 1\\n\",\n    \"if prompt_length + label_length > max_length:\\n\",\n    \"    num_tokens_to_remove = prompt_length + label_length - max_length\\n\",\n    \"    for _ in range(num_tokens_to_remove):\\n\",\n    \"        if prompt_length > label_length:\\n\",\n    \"            prompt_length -= 1\\n\",\n    \"        else:\\n\",\n    \"            label_length -= 1\\n\",\n    \"else:\\n\",\n    \"    label_length = max_length - prompt_length\\n\",\n    \"assert prompt_length > 0\\n\",\n    \"assert label_length > 0\\n\",\n    \"assert prompt_length + label_length <= max_length\\n\",\n    \"inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\\n\",\n    \"                              max_length=prompt_length,\\n\",\n    \"                              truncation=\\\"only_first\\\",\\n\",\n    \"                              return_tensors=\\\"pt\\\",\\n\",\n    \"                              return_attention_mask=True,\\n\",\n    \"                              return_token_type_ids=False)\\n\",\n    \"inputs_glm = tokenizer.build_inputs_for_generation(inputs, targets=label,\\n\",\n    \"                                                          max_gen_length=label_length, padding=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6521cafd\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer.padding_side = \\\"left\\\"\\n\",\n    \"encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\\n\",\n    \"prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\\n\",\n    \"                              max_length=max_prompt_length,\\n\",\n    \"                              padding=\\\"max_length\\\",\\n\",\n    \"                              truncation=\\\"only_first\\\",\\n\",\n    \"                              return_tensors=\\\"pt\\\",\\n\",\n    \"                              return_token_type_ids=False)\\n\",\n    \"# max_gen_length = max_length - prompt_length\\n\",\n    \"# 
max_gen_length = self.args.max_gen_length\\n\",\n    \"assert prompt_length > 0\\n\",\n    \"assert max_gen_length > 0\\n\",\n    \"assert prompt_length + max_gen_length <= max_length\\n\",\n    \"inputs_glm = tokenizer.build_inputs_for_generation(inputs, max_gen_length=max_gen_length, padding=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d6344ed6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs)\\n\",\n    \"print(inputs['input_ids'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2fb5447b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs_glm)\\n\",\n    \"print(inputs_glm['input_ids'].shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f7edfa03\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "notebook/sft.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6f526208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.insert(0, \\\"/root/autodl-tmp/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/Users/zeyesun/Documents/Code/RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"D:\\\\\\\\Code\\\\\\\\RLHF\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/sfevol775196/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\\\")\\n\",\n    \"sys.path.insert(0, \\\"/mnt/pa002-28359-vol543625-private/Code/chatgpt\\\")\\n\",\n    \"\\n\",\n    \"import os, time, re, random, glob, json, jieba, copy\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import torch\\n\",\n    \"import loralib as lora\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoModelForCausalLM,\\n\",\n    \"    AutoModelForSeq2SeqLM,\\n\",\n    \"    AutoModelForMultipleChoice,\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    Trainer,\\n\",\n    \"    TrainingArguments,\\n\",\n    \"    default_data_collator,\\n\",\n    \"    TextGenerationPipeline\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"from src.models.lora import convert_to_lora_recursively\\n\",\n    \"\\n\",\n    \"device=\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"from sys import platform\\n\",\n    \"if platform == \\\"linux\\\" or platform == \\\"linux2\\\":\\n\",\n    \"    # linux\\n\",\n    \"    root = \\\"/mnt/sfevol775196/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/pa002-28359-vol543625-private/Data\\\"\\n\",\n    \"#     root = \\\"/root/autodl-tmp/Data\\\"\\n\",\n    \"elif platform == \\\"darwin\\\":\\n\",\n    \"    # OS X\\n\",\n    \"    root = \\\"/Users/zeyesun/Documents/Data\\\"\\n\",\n    \"elif platform == \\\"win32\\\":\\n\",\n    \"    # Windows...\\n\",\n    \"    root = \\\"D:\\\\\\\\Data\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"633e69a5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# model_name = \\\"pangu-350M\\\"\\n\",\n    \"# model_name = \\\"glm-350M-chinese\\\"\\n\",\n    \"model_name = \\\"chatglm-6B\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"# model_name_or_path = os.path.join(\\\"/mnt/pa002-28359-vol543625-share/LLM-data/checkpoint\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"#     f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": 
\"7c62c599\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lora_rank = 0\\n\",\n    \"lora_alpha = 1\\n\",\n    \"lora_train_bias = \\\"none\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b2cae21a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if \\\"glm\\\" in model_name_or_path:\\n\",\n    \"    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"    if \\\"chatglm\\\" in model_name_or_path:\\n\",\n    \"        model = model.half()\\n\",\n    \"else:\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)\\n\",\n    \"    model.resize_token_embeddings(tokenizer.vocab_size)\\n\",\n    \"if lora_rank > 0:\\n\",\n    \"    convert_to_lora_recursively(model, lora_rank, lora_alpha)\\n\",\n    \"    lora.mark_only_lora_as_trainable(model, lora_train_bias)\\n\",\n    \"model.to(device)\\n\",\n    \"print(model.device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8a28f07a\",\n   \"metadata\": {},\n   \"source\": [\n    \"# SFT Prediction\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b003a1c8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"checkpoint_files = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", \\\"sft\\\", \\\"pangu-350M\\\", \\\"checkpoint-57043\\\", \\\"pytorch_model*.bin\\\")\\n\",\n    \"# checkpoint_files = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", \\\"sft\\\", \\\"pangu-2.6B\\\", \\\"pangu-2.6B-sft\\\", \\\"pytorch_model*.bin\\\")\\n\",\n    \"checkpoints = glob.glob(checkpoint_files)\\n\",\n    \"st = dict()\\n\",\n    \"for checkpoint in checkpoints:\\n\",\n    \"    st.update(torch.load(checkpoint, map_location=\\\"cpu\\\"))\\n\",\n    \"model.load_state_dict(st)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9330b5aa\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"max_length = 512\\n\",\n    \"max_length_generation = 50\\n\",\n    \"num_return_sequences = 1\\n\",\n    \"top_p = 0.8\\n\",\n    \"top_k = 50\\n\",\n    \"temperature = 5.0\\n\",\n    \"# prompt = '今天晚上我在睡觉.........他想要做那些事..我就大大声骂他\\\"不要吵我睡觉\\\"!!!!!...他就跑出去了...还不接我电话'\\n\",\n    \"# prompt = \\\"\\\"\\\"阅读文章：《战国无双3》（）是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴，分别是以武田信玄等人为主的《关东三国志》，织田信长等人为主的《战国三杰》，石田三成等人为主的《关原的年轻武者》，丰富游戏内的剧情。此部份专门介绍角色，欲知武器情报、奥义字或擅长攻击类型等，请至战国无双系列1.由于乡里大辅先生因故去世，不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图（不含村雨城），后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多，部分地图会有兼用的状况，战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主，以下是相关介绍。（注：前方加☆者为猛将传新增关卡及地图。）合并本篇和猛将传的内容，村雨城模式剔除，战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品\\\\n问：《战国无双3》是由哪两个公司合作开发的？\\\"\\\"\\\"\\n\",\n    \"# prefix = \\\"答：\\\"\\n\",\n    \"\\n\",\n    \"prompt = \\\"以王维的风格,写一首五言诗,题为和宋中丞夏日游福贤观天长寺寺即陈左相宅所施之作\\\"\\n\",\n    \"prefix = \\\"\\\"\\n\",\n    \"\\n\",\n    \"# prompt = \\\"上联：东风执笔点龙睛，看幸福指数，天天向上\\\"\\n\",\n    \"# prefix = \\\"下联：\\\"\\n\",\n    \"\\n\",\n    \"prefix = \\\"\\\"\\n\",\n    \"while True:\\n\",\n    \"    if \\\"chatglm\\\" in model_name_or_path:\\n\",\n    \"        encoded_prompt = tokenizer(prompt)\\n\",\n    \"        prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"        inputs = tokenizer(prompt, max_length=min(prompt_length, max_length), truncation=\\\"only_first\\\",\\n\",\n    \"                   
return_tensors=\\\"pt\\\")\\n\",\n    \"        inputs = inputs.to(device)\\n\",\n    \"        outputs = model.generate(**inputs,\\n\",\n    \"                                 max_new_tokens=max_length_generation,\\n\",\n    \"                                 eos_token_id=tokenizer.eop_token_id,\\n\",\n    \"                                 pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                                 do_sample=False,\\n\",\n    \"                                 num_return_sequences=num_return_sequences,\\n\",\n    \"                                 top_p=top_p,\\n\",\n    \"                                 temperature=temperature)\\n\",\n    \"    elif \\\"glm\\\" in model_name_or_path:\\n\",\n    \"        encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\\n\",\n    \"        prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"        inputs = tokenizer(prompt, prefix + tokenizer.mask_token,\\n\",\n    \"#                            max_length=max_length - max_length_generation,\\n\",\n    \"#                            padding=\\\"max_length\\\",\\n\",\n    \"                           max_length=min(prompt_length, max_length),\\n\",\n    \"                           truncation=\\\"only_first\\\",\\n\",\n    \"                           return_tensors=\\\"pt\\\",\\n\",\n    \"                           return_attention_mask=True,\\n\",\n    \"                           return_token_type_ids=False)\\n\",\n    \"        max_length_generation = max_length - inputs['input_ids'].shape[1]\\n\",\n    \"        inputs_glm = tokenizer.build_inputs_for_generation(inputs, \\n\",\n    \"                                                           max_gen_length=max_length_generation, padding=True)\\n\",\n    \"        inputs_glm = inputs_glm.to(device)\\n\",\n    \"        outputs = model.generate(**inputs_glm,\\n\",\n    \"                                 max_new_tokens=max_length_generation,\\n\",\n    \"                                 eos_token_id=tokenizer.eop_token_id,\\n\",\n    \"                                 pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                                 do_sample=False,\\n\",\n    \"                                 num_return_sequences=num_return_sequences,\\n\",\n    \"                                 top_p=top_p,\\n\",\n    \"                                 temperature=temperature)\\n\",\n    \"    else:\\n\",\n    \"        inputs = tokenizer(prompt, tokenizer.sep_token + prefix, \\n\",\n    \"                           max_length=max_length, \\n\",\n    \"                           return_tensors=\\\"pt\\\",\\n\",\n    \"                           truncation=\\\"only_first\\\", \\n\",\n    \"#                            padding=\\\"max_length\\\",\\n\",\n    \"                           add_special_tokens=False,\\n\",\n    \"                           return_token_type_ids=False)\\n\",\n    \"        inputs = inputs.to(device)\\n\",\n    \"        outputs = model.generate(**inputs,\\n\",\n    \"                                 max_new_tokens=max_length_generation,\\n\",\n    \"                                 pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                                 do_sample=True,\\n\",\n    \"                                 num_return_sequences=num_return_sequences,\\n\",\n    \"#                                  top_p=top_p,\\n\",\n    \"                                 top_k=top_k,\\n\",\n    \"                                 temperature=temperature)\\n\",\n    \"    
results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\\n\",\n    \"#     results = [result.split(prefix, maxsplit=1)[1] for result in results]\\n\",\n    \"    print(results)\\n\",\n    \"    break\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9dbb3482\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# f = os.path.join(root, \\\"raw\\\", \\\"baike_qa_train.json\\\")\\n\",\n    \"# [baike_qa.jsonl, chinese_classical.jsonl, chinese_poetry.jsonl, couplets.jsonl, weibo_summary_comments.jsonl, zhidao.jsonl]\\n\",\n    \"f = os.path.join(root, \\\"chatgpt\\\", \\\"processed\\\", \\\"baike_qa.jsonl\\\")\\n\",\n    \"i = 0\\n\",\n    \"prompts = []\\n\",\n    \"prompts_processed = []\\n\",\n    \"labels = []\\n\",\n    \"with open(f, \\\"r\\\", encoding=\\\"utf-8\\\") as r:\\n\",\n    \"    while True:\\n\",\n    \"        line = r.readline()\\n\",\n    \"        if not line:\\n\",\n    \"            break\\n\",\n    \"        item = json.loads(line.strip(\\\"\\\\n\\\"))\\n\",\n    \"        # prompt = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])\\n\",\n    \"        # prompt_processed = prompt + tokenizer.sep_token + prefix\\n\",\n    \"        # label = clean_text(item['answer'])\\n\",\n    \"        prompt = item['prompt']\\n\",\n    \"        prompt_processed = prompt\\n\",\n    \"        label = item['answers'][0]['answer']\\n\",\n    \"        prompts.append(prompt)\\n\",\n    \"        prompts_processed.append(prompt_processed)\\n\",\n    \"        labels.append(label)\\n\",\n    \"        i += 1\\n\",\n    \"        # if i > 1000:\\n\",\n    \"        #     break\\n\",\n    \"random.shuffle(prompts_processed)\\n\",\n    \"print(len(prompts_processed))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"0ad64148\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"i = 79\\n\",\n    \"num_return_sequences = 2\\n\",\n    \"max_length = 512\\n\",\n    \"max_length_generation = 100\\n\",\n    \"top_k = 50\\n\",\n    \"top_p = 0.8\\n\",\n    \"temperature = 1.0\\n\",\n    \"t1 = time.time()\\n\",\n    \"prompt = prompts_processed[i]\\n\",\n    \"inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\\\"pt\\\")\\n\",\n    \"inputs = inputs.to(device)\\n\",\n    \"outputs = model.generate(**inputs,\\n\",\n    \"                         max_new_tokens=max_length_generation,\\n\",\n    \"                         pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                         do_sample=True,\\n\",\n    \"                         num_return_sequences=num_return_sequences,\\n\",\n    \"                         # top_p=top_p,\\n\",\n    \"                         top_k=top_k,\\n\",\n    \"                         temperature=temperature)\\n\",\n    \"results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\\n\",\n    \"for result in results:\\n\",\n    \"    # result.split(prefix, maxsplit=1)[1]\\n\",\n    \"    print(result)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# results = text_generator(prompts_processed[i:j], max_length=200, num_return_sequences=num_return_sequences,\\n\",\n    \"#                          do_sample=True, top_k=50, temperature=10.0)\\n\",\n    \"# print(f\\\"Finished prediction, time taken: {time.time()-t1}\\\")\\n\",\n    \"\\n\",\n    \"# for prompt, res, label in zip(prompts[i:j], results[:(j-i)], labels[i:j]):\\n\",\n    \"#     
print(f\\\"prompt: {prompt}\\\\nlabel: {label}\\\")\\n\",\n    \"#     for k in range(num_return_sequences):\\n\",\n    \"#         model_answer = res[k]['generated_text'].split(prefix)[1].replace(\\\"<eot>\\\", \\\"\\\").replace(\\\"<pad>\\\", \\\"\\\")\\n\",\n    \"#         print(f\\\"model answer-{k}: {model_answer}\\\")\\n\",\n    \"#     print(\\\"\\\\n\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"da1e08ce\",\n   \"metadata\": {},\n   \"source\": [\n    \"# SFT Training\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9eef5821\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# prompt = \\\"你是谁\\\"\\n\",\n    \"# prefix = \\\"答:\\\"\\n\",\n    \"# label = \\\"我是***，很高兴为你服务\\\"\\n\",\n    \"prompt = \\\"\\\"\\\"倍数金额： 1倍，￥1024 元 场次 主---------客队 投注选项- 参考赔率---投注选项胜 平 负---\\n\",\n    \"第一次为基本面投注---第二次为通过处理后投注  1 伯明翰 VS -----维冈 31-----1.93 3.27 3.87 ---  2伯恩利VS---朴茨茅30----- ---3 博尔顿 VS -----狼队\\n\",\n    \"3------1.94 3.25 3.88 ---  4 斯托克 VS ---阿森纳 0------5.03 3.47 1.68 ---  5 门兴 VS -----弗赖堡 31-----1.77 3.倍数金额： 1倍，￥1024 元 场次\\n\",\n    \"主---------客队 投注选项- 参考赔率---投注选项胜 平 负--- 第一次为基本面投注---第二次为通过处理后投注  1 伯明翰 VS -----维冈 31-----1.93 3.27 3.87 ---\\n\",\n    \"2伯恩利VS---朴茨茅30----- ---3 博尔顿 VS -----狼队 3------1.94 3.25 3.88 ---  4 斯托克 VS ---阿森纳 0------5.03 3.47 1.68 ---  5 门兴 VS -----弗赖堡\\n\",\n    \"31-----1.77 3.39 4.43 ---  6 美因兹 VS ---不来梅 10-----3.76 3.34 1.92 ---  7波鸿VS-----纽伦堡30----- ---8 斯图加 VS ---法兰克 31-----1.59 3.62 5.47\\n\",\n    \"---  9 赫塔 VS -----霍芬海 30-----2.49 3.19 2.69 ---  10 勒沃 VS ------科隆 3------1.35 4.44 8.31 ---  11卡塔尼VS----巴里31----- ---12 拉齐奥 VS\\n\",\n    \"--佛罗伦 31-----2.35 3.05 3.01 ---  13 特内里 VS ----皇马 0------9.43 4.95 1.29 ---  14 巴萨 VS ----马拉加 3------1.15 6.78 15.49 --\\\"\\\"\\\"\\n\",\n    \"prefix = \\\"回答：\\\"\\n\",\n    \"label = \\\"你出的赔率数据太早了，数据随时都会变化，这就是所谓要看临盘的道理，目前的数据没什么参考价值。\\\"\\n\",\n    \"max_length = 512\\n\",\n    \"encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\\n\",\n    \"prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"label_length = len(tokenizer.tokenize(label)) + (1 if \\\"chatglm\\\" not in model_name_or_path else 0)\\n\",\n    \"# print(f\\\"prompt length: {prompt_length}, label length: {label_length}\\\")\\n\",\n    \"if prompt_length + label_length > max_length:\\n\",\n    \"    num_tokens_to_remove = prompt_length + label_length - max_length\\n\",\n    \"    for _ in range(num_tokens_to_remove):\\n\",\n    \"        if prompt_length > label_length:\\n\",\n    \"            prompt_length -= 1\\n\",\n    \"        else:\\n\",\n    \"            label_length -= 1\\n\",\n    \"else:\\n\",\n    \"    label_length = max_length - prompt_length\\n\",\n    \"assert prompt_length > 0\\n\",\n    \"assert label_length > 0\\n\",\n    \"assert prompt_length + label_length <= max_length\\n\",\n    \"encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\\n\",\n    \"                         max_length=prompt_length, truncation=\\\"only_first\\\",\\n\",\n    \"                         return_tensors=\\\"pt\\\", return_attention_mask=True)\\n\",\n    \"inputs = tokenizer.build_inputs_for_generation(encoded_dict, targets=label,\\n\",\n    \"                                               max_gen_length=label_length, padding=True)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5d8b9e45\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"#     f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f99c82e8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(tokenizer.convert_ids_to_tokens([20006]))\\n\",\n    \"print(tokenizer.convert_ids_to_tokens([20012]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"24d349a5\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(encoded_prompt)\\n\",\n    \"print(tokenizer.decode(encoded_prompt['input_ids']))\\n\",\n    \"print(encoded_dict)\\n\",\n    \"print(tokenizer.batch_decode(encoded_dict['input_ids']))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9a63f417\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for key, val in inputs.items():\\n\",\n    \"    print(f\\\"{key} shape: {val.shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7f78df96\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for key, val in inputs_glm.items():\\n\",\n    \"    print(f\\\"{key} shape: {val.shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4261c76d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs_glm['input_ids'][:20])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6667eaac\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs_glm['labels'][:20])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"ac236f19\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs_glm['attention_mask'][0][9])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"bda9e8ba\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(inputs_glm['position_ids'][0][:20])\\n\",\n    \"print(inputs_glm['position_ids'][1][:20])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"83126e24\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# st = model.state_dict()\\n\",\n    \"st.keys()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"7320e1ca\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(st['transformer.word_embeddings.weight'].dtype)\\n\",\n    \"print(st['transformer.layers.0.input_layernorm.weight'].dtype)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"81289f7f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dtypes = dict()\\n\",\n    \"for key, val in st.items():\\n\",\n    \"    if val.dtype not in 
dtypes:\\n\",\n    \"        dtypes[val.dtype] = list()\\n\",\n    \"    dtypes[val.dtype].append(key)\\n\",\n    \"print(dtypes.keys())\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3f4e3327\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"output = model(**inputs)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fcc223bf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "notebook/test.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"93c83ce9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os, time, re, random, glob, json, jieba, copy\\n\",\n    \"import numpy as np\\n\",\n    \"import pandas as pd\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoConfig,\\n\",\n    \"    AutoModelForCausalLM,\\n\",\n    \"    AutoModelForSeq2SeqLM,\\n\",\n    \"    AutoModelForMultipleChoice,\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    Trainer,\\n\",\n    \"    TrainingArguments,\\n\",\n    \"    default_data_collator\\n\",\n    \")\\n\",\n    \"device=\\\"cuda:0\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n    \"from sys import platform\\n\",\n    \"if platform == \\\"linux\\\" or platform == \\\"linux2\\\":\\n\",\n    \"    # linux\\n\",\n    \"    root = \\\"/mnt/sfevol775196/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/share-pa002-vol682688-prd/sunzeye273/Data\\\"\\n\",\n    \"#     root = \\\"/mnt/pa002-28359-vol543625-private/Data\\\"\\n\",\n    \"#     root = \\\"/root/autodl-tmp/Data\\\"\\n\",\n    \"elif platform == \\\"darwin\\\":\\n\",\n    \"    # OS X\\n\",\n    \"    root = \\\"/Users/zeyesun/Documents/Data\\\"\\n\",\n    \"elif platform == \\\"win32\\\":\\n\",\n    \"    # Windows...\\n\",\n    \"    root = \\\"D:\\\\\\\\Data\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3d6fdca2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"llama-7b\\\"\\n\",\n    \"# model_name = \\\"glm-350M-chinese\\\"\\n\",\n    \"# model_name = \\\"chatglm-6B\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)\\n\",\n    \"print(tokenizer.special_tokens_map)\\n\",\n    \"print(tokenizer.all_special_ids)\\n\",\n    \"print(\\n\",\n    \"    f\\\"unk: {tokenizer.unk_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"pad: {tokenizer.pad_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"bos: {tokenizer.bos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"eos: {tokenizer.eos_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"sep: {tokenizer.sep_token_id}\\\\n\\\",\\n\",\n    \"    f\\\"mask: {tokenizer.mask_token_id}\\\\n\\\",\\n\",\n    \"#     f\\\"eop: {tokenizer.eop_token_id}\\\\n\\\"\\n\",\n    \"#     f\\\"sop: {tokenizer.sop_token_id}\\\\n\\\"\\n\",\n    \"    f\\\"cls: {tokenizer.cls_token_id}\\\"\\n\",\n    \") \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"52c6acb9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"2916.6666666666665\"\n      ]\n     },\n     \"execution_count\": 2,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"35000/ (3*4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"df9a2572\",\n   \"metadata\": {},\n   \"source\": [\n    \"# LLaMa\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6ccd5025\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = \\\"llama-small\\\"\\n\",\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", model_name)\\n\",\n    \"# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, 
use_cache=False, trust_remote_code=True)\\n\",\n    \"config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d76f86cd\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"f = os.path.join(model_name_or_path, \\\"pytorch_model.bin\\\")\\n\",\n    \"torch.save(model.state_dict(), f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5405693b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompt = \\\"你好 </s> 你是谁 </s> 你好 </s> 你是谁\\\"\\n\",\n    \"eos_ids = [0, 5, 12, 16, 22]\\n\",\n    \"max_length = 512\\n\",\n    \"input_embeds = torch.rand(1, dtype=torch.float16)\\n\",\n    \"# tokenizer.tokenize(prompt, add_special_tokens=False)\\n\",\n    \"inputs = tokenizer(prompt,  max_length=max_length, return_tensors=\\\"pt\\\", truncation=\\\"longest_first\\\", \\n\",\n    \"          return_attention_mask=True, return_token_type_ids=False)\\n\",\n    \"# print(inputs)\\n\",\n    \"tgt_len = len(inputs['input_ids'][0])\\n\",\n    \"print(f\\\"tgt_len: {tgt_len}\\\")\\n\",\n    \"combined_attention_mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(input_embeds.dtype).min))\\n\",\n    \"# print(f\\\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\\\")\\n\",\n    \"for i in range(len(eos_ids)-1):\\n\",\n    \"    attention_mask = torch.ones((1, eos_ids[i+1]-eos_ids[i]), dtype=torch.long)\\n\",\n    \"#     print(f\\\"{i}, shape: {attention_mask.shape}, attention_mask: {attention_mask}\\\")\\n\",\n    \"    attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape, \\n\",\n    \"                                    input_embeds, 0)\\n\",\n    \"    combined_attention_mask[eos_ids[i]:eos_ids[i+1], eos_ids[i]:eos_ids[i+1]] = attention_mask\\n\",\n    \"#     print(f\\\"{i}-th, shape: {attention_mask.shape}, attention_mask: {attention_mask}\\\")\\n\",\n    \"#     break\\n\",\n    \"# print(f\\\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"317fe60a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def _make_causal_mask(\\n\",\n    \"    input_ids_shape, dtype, device, past_key_values_length = 0\\n\",\n    \"):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Make causal mask used for bi-directional self-attention.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    bsz, tgt_len = input_ids_shape\\n\",\n    \"    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\\n\",\n    \"    mask_cond = torch.arange(mask.size(-1), device=device)\\n\",\n    \"    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\\n\",\n    \"    mask = mask.to(dtype)\\n\",\n    \"\\n\",\n    \"    if past_key_values_length > 0:\\n\",\n    \"        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\\n\",\n    \"    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Copied from transformers.models.bart.modeling_bart._expand_mask\\n\",\n    \"def _expand_mask(mask, dtype, tgt_len):\\n\",\n    \"    
\\\"\\\"\\\"\\n\",\n    \"    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    bsz, src_len = mask.size()\\n\",\n    \"    tgt_len = tgt_len if tgt_len is not None else src_len\\n\",\n    \"\\n\",\n    \"    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\\n\",\n    \"\\n\",\n    \"    inverted_mask = 1.0 - expanded_mask\\n\",\n    \"\\n\",\n    \"    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\\n\",\n    \"\\n\",\n    \"def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):\\n\",\n    \"        # create causal mask\\n\",\n    \"        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\\n\",\n    \"        combined_attention_mask = None\\n\",\n    \"        if input_shape[-1] > 1:\\n\",\n    \"            combined_attention_mask = _make_causal_mask(\\n\",\n    \"                input_shape,\\n\",\n    \"                inputs_embeds.dtype,\\n\",\n    \"                device=inputs_embeds.device,\\n\",\n    \"                past_key_values_length=past_key_values_length,\\n\",\n    \"            )\\n\",\n    \"\\n\",\n    \"        if attention_mask is not None:\\n\",\n    \"            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\\n\",\n    \"            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\\n\",\n    \"                inputs_embeds.device\\n\",\n    \"            )\\n\",\n    \"            combined_attention_mask = (\\n\",\n    \"                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\\n\",\n    \"            )\\n\",\n    \"\\n\",\n    \"        return combined_attention_mask\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f589231f\",\n   \"metadata\": {},\n   \"source\": [\n    \"# GLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a7268832\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompt = \\\"你好\\\"\\n\",\n    \"prefix = \\\"答:\\\"\\n\",\n    \"max_length = 128\\n\",\n    \"encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\\n\",\n    \"prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\\n\",\n    \"                         max_length=min(prompt_length, max_length),\\n\",\n    \"                         truncation=\\\"only_first\\\",\\n\",\n    \"                         return_tensors=\\\"pt\\\",\\n\",\n    \"                         return_token_type_ids=False)\\n\",\n    \"max_gen_length = max_length - encoded_dict['input_ids'].shape[1]\\n\",\n    \"inputs = tokenizer.build_inputs_for_generation(encoded_dict, max_gen_length=max_gen_length, padding=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"a1fc75f2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"inputs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"bfd9e66a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"batch_size, seq_length = encoded_dict['input_ids'].shape[:2]\\n\",\n    \"attention_mask1 = encoded_dict['attention_mask']\\n\",\n    \"attention_mask2 = attention_mask1.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)\\n\",\n    
\"generation_attention_mask = torch.cat([attention_mask2.new_zeros((seq_length, max_gen_length)),\\n\",\n    \"                                       torch.tril(attention_mask2.new_ones((max_gen_length, max_gen_length)))],\\n\",\n    \"                                      dim=0).unsqueeze(0).expand(batch_size, -1, -1)\\n\",\n    \"attention_mask3 = torch.cat((attention_mask2, generation_attention_mask), dim=2)\\n\",\n    \"attention_mask4 = attention_mask3.unsqueeze(1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5b629712\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(attention_mask1.shape)\\n\",\n    \"print(attention_mask1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6fd9341e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(attention_mask2.shape)\\n\",\n    \"# print(attention_mask2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c9f75428\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(attention_mask3.shape)\\n\",\n    \"print(attention_mask3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4b600887\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(attention_mask4.shape)\\n\",\n    \"print(attention_mask4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7e028572\",\n   \"metadata\": {},\n   \"source\": [\n    \"# ChatGLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c25ba334\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name_or_path = os.path.join(root, \\\"models\\\", \\\"chatglm-6B\\\")\\n\",\n    \"# model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)\\n\",\n    \"# model = model.half().to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"cc68c1cc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"texts = [\\n\",\n    \"    \\\"你好\\\",\\n\",\n    \"    \\\"晚上睡不着应该怎么办\\\"\\n\",\n    \"]\\n\",\n    \"history = []\\n\",\n    \"for text in texts:\\n\",\n    \"    response, history = model.chat(tokenizer, text, history=history)\\n\",\n    \"    print(f\\\"问: {text}\\\\n答:{response}\\\\n\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b4ccfa2b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"max_length = 512\\n\",\n    \"prompt = \\\"晚上睡不着应该怎么办\\\"\\n\",\n    \"label = \\\"吃褪黑素\\\"\\n\",\n    \"encoded_dict = tokenizer(prompt, label,\\n\",\n    \"                         max_length=max_length, return_tensors=\\\"pt\\\",\\n\",\n    \"                         truncation=\\\"longest_first\\\", padding=\\\"max_length\\\", return_token_type_ids=False)\\n\",\n    \"print(encoded_dict.keys())\\n\",\n    \"print(encoded_dict['input_ids'].shape)\\n\",\n    \"print(encoded_dict['input_ids'][0, :50])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"50baebfc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ids = [ 94747, 103400,  84213,  86846, 150001, 150004,  20005,  84020, 110857,\\n\",\n    \"         84248,  84886,  20003]\\n\",\n    
\"print([tokenizer._convert_id_to_token(i) for i in ids])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f9104b06\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"input_ids = encoded_dict['input_ids']\\n\",\n    \"print(input_ids[0, :20])\\n\",\n    \"seqs = input_ids.tolist()\\n\",\n    \"mask_positions = [seq.index(tokenizer.gmask_token_id) for seq in seqs]\\n\",\n    \"gmask = True\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b7b577cc\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"batch_size, seq_length = input_ids.shape\\n\",\n    \"context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\\n\",\n    \"attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\\n\",\n    \"attention_mask.tril_()\\n\",\n    \"for i, context_length in enumerate(context_lengths):\\n\",\n    \"    attention_mask[i, :, :context_length] = 1\\n\",\n    \"attention_mask.unsqueeze_(1)\\n\",\n    \"# attention_mask = (attention_mask < 0.5).bool()\\n\",\n    \"print(attention_mask.shape)\\n\",\n    \"print(attention_mask[0, :20, :20])\\n\",\n    \"batch_size, seq_length = input_ids.shape\\n\",\n    \"context_lengths = [seq.tolist().index(tokenizer.bos_token_id) for seq in input_ids]\\n\",\n    \"# if self.position_encoding_2d:\\n\",\n    \"position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\\n\",\n    \"if not gmask:\\n\",\n    \"    for i, context_length in enumerate(context_lengths):\\n\",\n    \"        position_ids[i, context_length:] = mask_positions[i]\\n\",\n    \"block_position_ids = [torch.cat((\\n\",\n    \"    torch.zeros(context_length, dtype=torch.long, device=device),\\n\",\n    \"    torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\\n\",\n    \")) for context_length in context_lengths]\\n\",\n    \"block_position_ids = torch.stack(block_position_ids, dim=0)\\n\",\n    \"position_ids = torch.stack((position_ids, block_position_ids), dim=1)\\n\",\n    \"# else:\\n\",\n    \"#     position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\\n\",\n    \"#     if not gmask:\\n\",\n    \"#         for i, context_length in enumerate(context_lengths):\\n\",\n    \"#             position_ids[context_length:] = mask_positions[i]\\n\",\n    \"print(position_ids.shape)\\n\",\n    \"print(position_ids[0, :, :20])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d91e23d9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(tokenizer.tokenize(prompt))\\n\",\n    \"print(tokenizer(prompt))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e5d10f8c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"encoded_prompt = tokenizer(prompt)\\n\",\n    \"prompt_length = len(encoded_prompt['input_ids'])\\n\",\n    \"inputs = tokenizer(prompt, max_length=min(prompt_length, max_length), truncation=\\\"only_first\\\",\\n\",\n    \"                   return_tensors=\\\"pt\\\")\\n\",\n    \"inputs = inputs.to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d6cd0dfa\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"max_length_generation = 50\\n\",\n    \"num_return_sequences = 1\\n\",\n    \"top_p = 0.8\\n\",\n    
\"temperature = 1.0\\n\",\n    \"outputs = model.generate(**inputs, max_new_tokens=max_length_generation,\\n\",\n    \"                         eos_token_id=tokenizer.eop_token_id,\\n\",\n    \"                         pad_token_id=tokenizer.pad_token_id,\\n\",\n    \"                         do_sample=False,\\n\",\n    \"                         num_return_sequences=num_return_sequences,\\n\",\n    \"                         top_p=top_p,\\n\",\n    \"                         temperature=temperature)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6f13279f\",\n   \"metadata\": {},\n   \"source\": [\n    \"# DeepSpeed\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7798b57b\",\n   \"metadata\": {},\n   \"source\": [\n    \"### estimate the memory needs for params, optim states and gradients\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"07493af2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# stage 1 and 2\\n\",\n    \"from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live\\n\",\n    \"estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=3, num_nodes=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4cb3b908\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# stage 3\\n\",\n    \"from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live\\n\",\n    \"estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=3, num_nodes=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e9aa232d\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Split torch checkpoint into multiple checkpoints\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"705292a3\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"checkpoint = os.path.join(root, \\\"chatgpt\\\", \\\"output\\\", \\\"sft\\\", \\\"pangu-2.6B\\\", \\\"checkpoint-42782\\\")\\n\",\n    \"if \\\"glm\\\" in model_name_or_path:\\n\",\n    \"    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)\\n\",\n    \"else:\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, use_cache=False)\\n\",\n    \"st = model.state_dict()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e6d5ee32\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"print(sys.getsizeof(st))\\n\",\n    \"keys = list(st.keys())\\n\",\n    \"n = 10\\n\",\n    \"m = {\\\"metadata\\\": {\\\"total_size\\\":sys.getsizeof(st)}, \\\"weight_map\\\":dict()}\\n\",\n    \"span = len(keys) // n\\n\",\n    \"for i in range(n):\\n\",\n    \"    fn = f\\\"pytorch_model-{i+1}-of-{n}.bin\\\"\\n\",\n    \"    f = os.path.join(checkpoint, fn)\\n\",\n    \"    stt = dict()\\n\",\n    \"    for key in keys[i*span:(i+1)*span]:\\n\",\n    \"        stt[key] = st[key]\\n\",\n    \"        m[\\\"weight_map\\\"][key] = fn\\n\",\n    \"    torch.save(stt, f)\\n\",\n    \"f = os.path.join(checkpoint, \\\"pytorch_model.bin.index.json\\\")\\n\",\n    \"json.dump(m, open(f, \\\"w\\\", encoding=\\\"utf-8\\\"), ensure_ascii=False)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   
\"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "requirements.txt",
    "content": "jieba\nsentencepiece\nnltk>=3.8.1\nrouge-score>=0.1.2\nevaluate>=0.4.0\nnumpy==1.23.2\ntensorflow>=2.5.0\ntransformers>=4.31.0\ntokenizers>=0.13.3\n#triton==1.0.0\ntriton\nninja\ntorchmetrics\nopencc\ntorchtyping\naccelerate>=0.19.0\nray\ntqdm\nrich\nicetk\ncpm_kernels\nprotobuf==3.20.0\npynvml\nbitsandbytes>=0.39.0\npeft>=0.4.0\ntransformers_stream_generator\ntrl\ntrlx; trlx"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/data/__init__.py",
    "content": ""
  },
  {
    "path": "src/data/data.py",
    "content": "import os\nimport json\nimport re\nimport random\nfrom typing import Tuple, List\n\nimport torch\nimport pandas as pd\n\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset\nfrom torch.nn.utils.rnn import pad_sequence\nimport torch.nn.functional as F\nfrom transformers import PreTrainedTokenizerBase\n\nfrom src.utils import logger, RESOURCE_PATH\nfrom src.utils.modeling_utils import _prepare_decoder_attention_mask, qwen_make_context\nfrom src.utils.file_utils import print_rank_0\n\n\ndef chatglm3_encode(tokenizer: PreTrainedTokenizerBase,\n                    query: str,\n                    label: str = None,\n                    system: str = \"\",\n                    max_length: int = 1024,\n                    is_prefix: bool = True\n                    ) -> Tuple[List[int], List[int], List[int]]:\n    '''Use chatglm3 tokenizer to encode prompt + label with \"longest_first\" truncation strategy\n\n    :param tokenizer:\n    :param prompt:\n    :param label:\n    :param system:\n    :param max_length:\n    :return:\n    '''\n    prefix_tokens = tokenizer.get_prefix_tokens()\n    role_tokens_1 = [tokenizer.get_command(f\"<|user|>\")] + tokenizer.encode(f\"\\n\", add_special_tokens=False)\n    # Process `system` and `query`\n    if is_prefix:\n        system_ids = tokenizer.encode(system + \"\\n\\n\", add_special_tokens=False) if len(system) > 0 else []\n        query_ids = tokenizer.encode(\" \" + query, add_special_tokens=False)[1:]\n    else:\n        system_ids = tokenizer.encode(\" \\n\\n\" + system, add_special_tokens=False)[1:] if len(system) > 0 else []\n        query_ids = tokenizer.encode(query, add_special_tokens=False)\n    # Process `label`\n    role_tokens_2 = [tokenizer.get_command(f\"<|assistant|>\")]\n    if label is not None:\n        label_ids = tokenizer.encode(label, add_special_tokens=False)\n        end_tokens = [tokenizer.get_command(\"<eos>\")]\n    else:\n        label_ids = []\n        end_tokens = []\n    # Remove overflowing tokens\n    num_tokens_to_remove = len(prefix_tokens) + len(role_tokens_1) + len(query_ids) + len(system_ids) + \\\n                           len(role_tokens_2) + len(label_ids) + len(end_tokens) - max_length\n    if num_tokens_to_remove > 0:\n        for _ in range(num_tokens_to_remove):\n            if len(query_ids) + len(system_ids) > len(label_ids) and len(query_ids) > 0:\n                query_ids.pop()\n            elif len(label_ids) > 0:\n                label_ids.pop()\n            else:\n                logger.warn(\"removing system tokens due to tokens overflowing\")\n                system_ids.pop()\n        if label is not None:\n            label_ids += end_tokens\n    else:\n        if label is not None:\n            label_ids += end_tokens\n        label_ids += [tokenizer.pad_token_id] * -num_tokens_to_remove\n\n    if is_prefix:\n        prompt_ids = prefix_tokens + role_tokens_1 + system_ids + query_ids + role_tokens_2\n    else:\n        prompt_ids = prefix_tokens + role_tokens_1 + query_ids + system_ids + role_tokens_2\n    input_ids = prompt_ids + label_ids\n    labels = [tokenizer.pad_token_id] * len(prompt_ids) + label_ids\n    assert len(input_ids) == len(labels) == max_length\n    return input_ids, labels, prompt_ids\n\n\ndef chatglm2_encode(tokenizer: PreTrainedTokenizerBase,\n                    query: str,\n                    label: str = None,\n                    system: str = \"\",\n                    max_length: int = 1024,\n                    is_prefix: bool = True\n            
         ) -> Tuple[List[int], List[int], List[int]]:\n    '''Use chatglm2 tokenizer to encode query + label with \"longest_first\" truncation strategy\n\n    :param tokenizer:\n    :param query:\n    :param label:\n    :param system:\n    :param max_length:\n    :param is_prefix: if True, place the system text before the query; otherwise after it\n    :return:\n    '''\n    gmask_id = tokenizer.get_command(\"[gMASK]\")\n    sop_id = tokenizer.get_command(\"sop\")\n    eop_id = tokenizer.get_command(\"eop\")\n    # [Round {1}]\\n\\n问：\n    ids1 = [790, 30951, 517, 30910, 30939, 30996, 13, 13, 54761, 31211]\n    # \\n\\n答：\n    ids2 = [13, 13, 55437, 31211]\n    if len(system) > 0:\n        if is_prefix:\n            system_ids = tokenizer.encode(\" \" + system + \"\\n\\n\", add_special_tokens=False)[1:]\n        else:\n            system_ids = tokenizer.encode(\" \\n\\n\" + system, add_special_tokens=False)[1:]\n    else:\n        system_ids = []\n    query_ids = tokenizer.encode(\" \" + query, add_special_tokens=False)[1:]\n    if label is not None:\n        label_ids = tokenizer.encode(label, add_special_tokens=False)\n        num_special_tokens = 3\n    else:\n        label_ids = []\n        num_special_tokens = 2\n    num_tokens_to_remove = len(ids1) + len(query_ids) + len(system_ids) + len(ids2) + \\\n                           len(label_ids) + num_special_tokens - max_length\n    if num_tokens_to_remove > 0:\n        for _ in range(num_tokens_to_remove):\n            if len(query_ids) + len(system_ids) > len(label_ids) and len(query_ids) > 0:\n                query_ids.pop()\n            elif len(label_ids) > 0:\n                label_ids.pop()\n            else:\n                logger.warning(\"removing system tokens due to token overflow\")\n                system_ids.pop()\n        if label is not None:\n            label_ids += [eop_id]\n    else:\n        if label is not None:\n            label_ids += [eop_id]\n        label_ids += [tokenizer.pad_token_id] * -num_tokens_to_remove\n    if is_prefix:\n        prompt_ids = [gmask_id, sop_id] + ids1 + system_ids + query_ids + ids2\n    else:\n        prompt_ids = [gmask_id, sop_id] + ids1 + query_ids + system_ids + ids2\n    input_ids = prompt_ids + label_ids\n    labels = [tokenizer.pad_token_id] * len(prompt_ids) + label_ids\n    assert len(input_ids) == len(labels) == max_length\n    return input_ids, labels, prompt_ids\n\n\nclass DataCollatorReward:\n    def __call__(self, data):\n        has_attention_mask = 'attention_mask' in data[0]\n        batch = {\n            \"chosen_input_ids\": torch.stack([f['input_ids'] for f in data]),\n            \"chosen_attention_mask\": torch.stack([f['attention_mask'] for f in data]) if has_attention_mask else None,\n            # \"input_ids\": torch.cat([f[0] for f in data] + [f[2] for f in data]),\n            # \"attention_mask\": torch.cat([f[1] for f in data] + [f[3] for f in data]),\n            # \"labels\": torch.tensor([0] * len(data) + [1] * len(data))\n        }\n        return batch\n\n\nclass DataCollatorRLHF:\n\n    def __init__(self, max_token_len, inference_tp_size):\n        self.max_token_len = max_token_len\n        self.inference_tp_size = inference_tp_size\n\n    def __call__(self, data):\n        batch = {}\n        pad_token_id = data[-1][-1]\n\n        prompt = pad_sequence([f[0] for f in data],\n                              padding_value=pad_token_id,\n                              batch_first=True)\n        prompt_mask = pad_sequence([f[1] for f in data],\n                                   padding_value=0,\n                                   
batch_first=True)\n\n        ### make sure the final output is a sequence of 2**?\n        length = prompt.size()[-1]\n        pad_length = self.max_token_len - length\n        if pad_length > 0:\n            batch[\"prompt\"] = F.pad(prompt,\n                                    pad=(pad_length, 0),\n                                    mode='constant',\n                                    value=pad_token_id)\n            batch[\"prompt_att_mask\"] = F.pad(prompt_mask,\n                                             pad=(pad_length, 0),\n                                             mode='constant',\n                                             value=0)\n        else:\n            batch[\"prompt\"] = prompt\n            batch[\"prompt_att_mask\"] = prompt_mask\n        batch[\"prompt\"] = batch[\"prompt\"].flip(1)\n        batch[\"prompt_att_mask\"] = batch[\"prompt_att_mask\"].flip(1)\n        return batch\n\n\nclass PretrainDataset(Dataset):\n    def __init__(self, args, filename, tokenizer, concat_samples=True):\n        self.args = args\n        self.tokenizer = tokenizer\n        self.concat_samples = concat_samples\n        self.model_name_or_path = args.model_name_or_path if hasattr(args,\n                                                                     \"model_name_or_path\") else args.actor_model_path\n\n        self.post_list = self.load_dataset(filename)\n        for k in range(5):\n            print_rank_0(f\"PretrainDataset sample-{k}\\n: {self.post_list[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        if not self.concat_samples:\n            prompt = data['prompt']\n            label = data.get('label', None)\n            if \"glm\" in self.model_name_or_path.lower() and \"chatglm\" not in self.model_name_or_path.lower():\n                encoded_prompt = self.tokenizer(prompt, self.tokenizer.mask_token)\n                prompt_length = len(encoded_prompt['input_ids'])\n                label_length = len(self.tokenizer.tokenize(label)) + 1\n                if prompt_length + label_length > self.args.max_length:\n                    num_tokens_to_remove = prompt_length + label_length - self.args.max_length\n                    for _ in range(num_tokens_to_remove):\n                        if prompt_length > label_length:\n                            prompt_length -= 1\n                        else:\n                            label_length -= 1\n                else:\n                    label_length = self.args.max_length - prompt_length\n                assert prompt_length > 0\n                assert label_length > 0\n                assert prompt_length + label_length == self.args.max_length\n                encoded_dict = self.tokenizer(prompt, self.tokenizer.mask_token,\n                                              max_length=prompt_length,\n                                              truncation=\"only_first\",\n                                              return_tensors=\"pt\",\n                                              return_attention_mask=True,\n                                              return_token_type_ids=False)\n                encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict, targets=label,\n                                                                          max_gen_length=label_length, padding=True)\n                return {\n                    \"input_ids\": encoded_dict['input_ids'][0],\n                    
\"position_ids\": encoded_dict['position_ids'][0],\n                    \"attention_mask\": encoded_dict['attention_mask'][0],\n                    \"labels\": encoded_dict['labels'][0],\n                }\n            else:\n                if \"chatglm2\" in self.model_name_or_path.lower():\n                    prompt = f\"[Round {1}]\\n\\n问：{prompt}\\n\\n答：\"\n                    label = label\n                elif \"chatglm\" in self.model_name_or_path.lower():\n                    prompt = f\"[Round {0}]\\n问：{prompt}\\n答：\"\n                    label = label\n                elif \"vicuna\" in self.model_name_or_path.lower():\n                    prompt += \"\\n\\n\" + label\n                    label = None\n                else:\n                    label = None\n                encoded_dict = self.tokenizer(prompt, label,\n                                              max_length=self.args.max_length,\n                                              truncation=\"longest_first\",\n                                              padding=\"max_length\",\n                                              return_token_type_ids=False,\n                                              return_tensors=\"pt\", )\n                if \"pangu\" in self.model_name_or_path.lower():\n                    return {\n                        \"input_ids\": encoded_dict['input_ids'],\n                        \"attention_mask\": encoded_dict['attention_mask'],\n                        \"labels\": encoded_dict['input_ids'],\n                    }\n                else:\n                    result = {\n                        \"input_ids\": encoded_dict['input_ids'][0],\n                        \"labels\": encoded_dict['input_ids'][0],\n                    }\n                    if 'attention_mask' in encoded_dict:\n                        result[\"attention_mask\"] = encoded_dict['attention_mask'][0]\n                    return result\n        else:\n            eos_ids = data['eos_ids']\n            input_ids = data['input_ids']\n            combined_attention_mask = torch.full((self.args.max_length, self.args.max_length),\n                                                 torch.tensor(torch.finfo(torch.float16).min))\n            for i in range(len(eos_ids) - 1):\n                attention_mask = torch.ones((1, eos_ids[i + 1] - eos_ids[i]), dtype=torch.long)\n                attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape,\n                                                                 input_embeds=torch.ones(1, dtype=torch.float16,\n                                                                                         device=\"cpu\"),\n                                                                 past_key_values_length=0)\n                logger.debug(f\"{i}-th sample, shape: {attention_mask.shape}, attention_mask: {attention_mask}\")\n                combined_attention_mask[eos_ids[i]:eos_ids[i + 1], eos_ids[i]:eos_ids[i + 1]] = attention_mask\n            logger.debug(f\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\")\n            if \"chatglm2\" in self.model_name_or_path.lower():\n                return {\n                    \"input_ids\": input_ids,\n                    \"labels\": input_ids,\n                    \"full_attention_mask\": combined_attention_mask,\n                }\n            else:\n                return {\n                    \"input_ids\": input_ids,\n                    \"labels\": input_ids,\n     
               \"attention_mask\": combined_attention_mask,\n                }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            data = []\n            eos_ids = [0]\n            length = 0\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                prompt = str(item['prompt'])\n                label = item.get('label', None)\n                if len(prompt) <= 0:\n                    discard += 1\n                    continue\n                if not self.concat_samples:\n                    datasets.append({\"prompt\": prompt, \"label\": label})\n                else:\n                    if \"chatglm2\" not in self.model_name_or_path.lower():\n                        assert \"glm\" not in self.model_name_or_path.lower(), \\\n                            \"Concatenating samples for GLM or ChatGLM not implemented yet\"\n                    if \"chatglm2\" in self.model_name_or_path.lower():\n                        prompt = f\"[Round {1}]\\n\\n问：{prompt}\\n\\n答：\"\n                    else:\n                        prompt = prompt if label is None else \"\\n\\n\".join((prompt, label))\n                        label = None\n                    token_ids = self.tokenizer.encode(prompt, label,\n                                                      max_length=self.args.max_length - length,\n                                                      truncation=\"longest_first\")\n                    if length + len(token_ids) < self.args.max_length:\n                        data.extend(token_ids)\n                        length += len(token_ids)\n                        eos_ids.append(length)\n                    else:\n                        data.extend(token_ids[:(self.args.max_length - length)])\n                        eos_ids.append(self.args.max_length)\n                        datasets.append({\"input_ids\": data, \"eos_ids\": eos_ids})\n                        data = []\n                        eos_ids = [0]\n                        length = 0\n        print_rank_0(\n            f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass SFTDataset(Dataset):\n    def __init__(self, args, filename, tokenizer, concat_samples=True):\n        self.args = args\n        self.tokenizer = tokenizer\n        self.concat_samples = concat_samples\n        self.model_name_or_path = args.model_name_or_path if hasattr(args,\n                                                                     \"model_name_or_path\") else args.actor_model_path\n\n        self.post_list = self.load_dataset(filename)\n        for k in range(5):\n            print_rank_0(f\"SFTDataset sample-{k}\\n: {self.post_list[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        if not self.concat_samples:\n            prompt = data['prompt']\n            label = data['label']\n            prefix = data['prefix']\n            system = data['system']\n            if \"glm\" in self.model_name_or_path.lower() and \"chatglm\" not in self.model_name_or_path.lower():\n                encoded_prompt = self.tokenizer(prompt, prefix + self.tokenizer.mask_token)\n                prompt_length = len(encoded_prompt['input_ids'])\n                label_length = 
len(self.tokenizer.tokenize(label)) + 1\n                if prompt_length + label_length > self.args.max_length:\n                    num_tokens_to_remove = prompt_length + label_length - self.args.max_length\n                    for _ in range(num_tokens_to_remove):\n                        if prompt_length > label_length:\n                            prompt_length -= 1\n                        else:\n                            label_length -= 1\n                else:\n                    label_length = self.args.max_length - prompt_length\n                assert prompt_length > 0\n                assert label_length > 0\n                assert prompt_length + label_length == self.args.max_length\n                encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,\n                                              max_length=prompt_length,\n                                              truncation=\"only_first\",\n                                              return_tensors=\"pt\",\n                                              return_attention_mask=True,\n                                              return_token_type_ids=False)\n                encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict, targets=label,\n                                                                          max_gen_length=label_length, padding=True)\n                return {\n                    \"input_ids\": encoded_dict['input_ids'][0],\n                    \"position_ids\": encoded_dict['position_ids'][0],\n                    \"attention_mask\": encoded_dict['attention_mask'][0],\n                    \"labels\": encoded_dict['labels'][0],\n                }\n            elif \"pangu\" in self.model_name_or_path.lower():\n                label = prefix + label\n                encoded_dict = self.tokenizer(prompt, label,\n                                              max_length=self.args.max_length,\n                                              truncation=\"longest_first\",\n                                              padding=\"max_length\",\n                                              return_token_type_ids=False,\n                                              return_tensors=\"pt\", )\n                return {\n                    \"input_ids\": encoded_dict['input_ids'],\n                    \"attention_mask\": encoded_dict['attention_mask'],\n                    \"labels\": encoded_dict['input_ids'],\n                }\n            elif \"chatglm3\" in self.model_name_or_path.lower():\n                input_ids, labels, _ = chatglm3_encode(self.tokenizer, prompt, label, system, self.args.max_length)\n                return {\n                    \"input_ids\": torch.tensor(input_ids, dtype=torch.long),\n                    # \"attention_mask\": torch.ones(len(input_ids), dtype=torch.long),\n                    \"labels\": torch.tensor(labels, dtype=torch.long),\n                }\n            elif \"chatglm2\" in self.model_name_or_path.lower():\n                input_ids, labels, _ = chatglm2_encode(self.tokenizer, prompt, label, system, self.args.max_length)\n                return {\n                    \"input_ids\": torch.tensor(input_ids, dtype=torch.long),\n                    # \"attention_mask\": torch.ones(len(input_ids), dtype=torch.long),\n                    \"labels\": torch.tensor(labels, dtype=torch.long),\n                }\n            elif \"chatglm\" in self.model_name_or_path.lower():\n                prompt = f\"[Round {0}]\n问：{prompt}\n答：\"\n                encoded_dict = self.tokenizer(prompt, label,\n                                              max_length=self.args.max_length,\n                                              truncation=\"longest_first\",\n                                              padding=\"max_length\",\n                                              return_token_type_ids=False,\n                                              return_tensors=\"pt\", )\n                return {\n                    \"input_ids\": encoded_dict['input_ids'][0],\n                    \"attention_mask\": encoded_dict['attention_mask'][0],\n                    \"labels\": encoded_dict['input_ids'][0],\n                }\n            else:\n                encoded_dict = self.tokenizer(prompt, label,\n                                              max_length=self.args.max_length,\n                                              truncation=\"longest_first\",\n                                              padding=\"max_length\",\n                                              return_token_type_ids=False,\n                                              return_tensors=\"pt\", )\n                result = {\n                    \"input_ids\": encoded_dict['input_ids'][0],\n                    \"labels\": encoded_dict['input_ids'][0],\n                }\n                if 'attention_mask' in encoded_dict:\n                    result[\"attention_mask\"] = encoded_dict['attention_mask'][0]\n                return result\n        else:\n            eos_ids = data['eos_ids']\n            input_ids = data['input_ids']\n            combined_attention_mask = torch.full((self.args.max_length, self.args.max_length),\n                                                 torch.tensor(torch.finfo(torch.float16).min))\n            for i in range(len(eos_ids) - 1):\n                attention_mask = torch.ones((1, eos_ids[i + 1] - 
eos_ids[i]), dtype=torch.long)\n                attention_mask = _prepare_decoder_attention_mask(attention_mask, attention_mask.shape,\n                                                                 input_embeds=torch.ones(1, dtype=torch.float16,\n                                                                                         device=\"cpu\"),\n                                                                 past_key_values_length=0)\n                logger.debug(f\"{i}-th sample, shape: {attention_mask.shape}, attention_mask: {attention_mask}\")\n                combined_attention_mask[eos_ids[i]:eos_ids[i + 1], eos_ids[i]:eos_ids[i + 1]] = attention_mask\n            logger.debug(f\"shape: {combined_attention_mask.shape}, combined_attention_mask: {combined_attention_mask}\")\n            if \"chatglm2\" in self.model_name_or_path.lower():\n                return {\n                    \"input_ids\": input_ids,\n                    \"labels\": input_ids,\n                    \"full_attention_mask\": combined_attention_mask,\n                }\n            else:\n                return {\n                    \"input_ids\": input_ids,\n                    \"labels\": input_ids,\n                    \"attention_mask\": combined_attention_mask,\n                }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            data = []\n            eos_ids = [0]\n            length = 0\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                data_type = item.get('data_type', \"human_generated\")\n                if data_type != \"human_generated\":\n                    continue\n                prompt = str(item['prompt'])\n                label = str(item['answers'][0]['answer'])\n                score = item['answers'][0]['score']\n                prefix = item.get('prefix', \"\")\n                system = item.get('system', \"\")\n                if len(prompt) <= 0 or len(label) <= 0:\n                    discard += 1\n                    continue\n\n                if not self.concat_samples:\n                    datasets.append({\"prompt\": prompt, \"label\": label, \"prefix\": prefix, \"system\": system})\n                else:\n                    if \"chatglm2\" not in self.model_name_or_path.lower():\n                        assert \"glm\" not in self.model_name_or_path.lower(), \\\n                            \"Concatenating samples for GLM or ChatGLM not implemented yet\"\n                    # keep prompt/label separate for chatglm2; otherwise merge them into a single span\n                    if \"chatglm2\" in self.model_name_or_path.lower():\n                        prompt = f\"[Round {1}]\n\n问：{prompt}\n\n答：\"\n                    else:\n                        prompt = prompt if label is None else \"\n\n\".join((prompt, label))\n                        label = None\n                    token_ids = self.tokenizer.encode(prompt, label,\n                                                      max_length=self.args.max_length - length,\n                                                      truncation=\"longest_first\")\n                    if length + len(token_ids) < self.args.max_length:\n                        data.extend(token_ids)\n                        length += len(token_ids)\n                        eos_ids.append(length)\n                    else:\n                        data.extend(token_ids[:(self.args.max_length - length)])\n                        eos_ids.append(self.args.max_length)\n                        datasets.append({\"input_ids\": data, \"eos_ids\": eos_ids})\n                        data = []\n                        eos_ids = [0]\n                        length = 0\n\n        print_rank_0(\n            f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass PairwiseDataset(Dataset):\n    def __init__(self, args, filename, tokenizer):\n        self.pairs = self.load_dataset(filename)\n        self.args = args\n        self.tokenizer = tokenizer\n\n        for k in range(5):\n            print_rank_0(f\"PairwiseDataset sample-{k}\n: {self.pairs[k]}\")\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def __getitem__(self, idx):\n        pair = self.pairs[idx]\n\n        prompt = pair[\"prompt\"]\n        chosen_answer = pair[\"chosen_answer\"]\n        rejected_answer = pair[\"rejected_answer\"]\n        prefix = pair['prefix']\n        system = pair['system']\n        if \"pangu\" in self.args.model_name_or_path.lower():\n            chosen_encodings_dict = self.tokenizer(prompt, prefix + chosen_answer, max_length=self.args.max_length,\n                                                   truncation=\"longest_first\", padding=\"max_length\",\n                                                   return_tensors=\"pt\",\n                                                   return_token_type_ids=False)\n            rejected_encodings_dict = self.tokenizer(prompt, prefix + rejected_answer, max_length=self.args.max_length,\n                                                     truncation=\"longest_first\", padding=\"max_length\",\n                                                     return_tensors=\"pt\",\n                                                     return_token_type_ids=False)\n            return {\n                \"chosen_input_ids\": chosen_encodings_dict[\"input_ids\"],\n                \"chosen_attention_mask\": chosen_encodings_dict[\"attention_mask\"],\n                \"rejected_input_ids\": rejected_encodings_dict[\"input_ids\"],\n                \"rejected_attention_mask\": rejected_encodings_dict[\"attention_mask\"],\n                \"labels\": rejected_encodings_dict[\"input_ids\"],\n            }\n        elif \"chatglm3\" in self.args.model_name_or_path.lower():\n            chosen_input_ids, labels, _ = chatglm3_encode(self.tokenizer, prompt, chosen_answer, system,\n                                                          self.args.max_length)\n            rejected_input_ids, labels, _ = chatglm3_encode(self.tokenizer, prompt, rejected_answer, system,\n                                                            self.args.max_length)\n            return {\n                \"chosen_input_ids\": torch.tensor(chosen_input_ids, dtype=torch.long),\n                \"rejected_input_ids\": torch.tensor(rejected_input_ids, dtype=torch.long),\n                \"labels\": torch.tensor(labels, dtype=torch.long)\n            }\n        elif \"chatglm2\" in self.args.model_name_or_path.lower():\n            chosen_input_ids, labels, _ = chatglm2_encode(self.tokenizer, prompt, chosen_answer, system,\n                                                          self.args.max_length)\n            rejected_input_ids, labels, _ = chatglm2_encode(self.tokenizer, prompt, rejected_answer, system,\n                                     
                       self.args.max_length)\n            return {\n                \"chosen_input_ids\": torch.tensor(chosen_input_ids, dtype=torch.long),\n                \"rejected_input_ids\": torch.tensor(rejected_input_ids, dtype=torch.long),\n                \"labels\": torch.tensor(labels, dtype=torch.long)\n            }\n        elif \"chatglm\" in self.args.model_name_or_path.lower():\n            prompt = f\"[Round {0}]\\n问：{prompt}\\n答：\"\n            chosen_encodings_dict = self.tokenizer(prompt, chosen_answer, max_length=self.args.max_length,\n                                                   truncation=\"longest_first\", padding=\"max_length\",\n                                                   return_tensors=\"pt\")\n            rejected_encodings_dict = self.tokenizer(prompt, rejected_answer, max_length=self.args.max_length,\n                                                     truncation=\"longest_first\", padding=\"max_length\",\n                                                     return_tensors=\"pt\")\n            return {\n                \"chosen_input_ids\": chosen_encodings_dict[\"input_ids\"][0],\n                \"rejected_input_ids\": rejected_encodings_dict[\"input_ids\"][0],\n                \"labels\": rejected_encodings_dict[\"input_ids\"][0],\n            }\n        elif \"glm\" in self.args.model_name_or_path.lower():\n            chosen_prompt_length = len(self.tokenizer.tokenize(prompt + prefix)) + 4\n            rejected_prompt_length = chosen_prompt_length\n            chosen_answer_length = len(self.tokenizer.tokenize(chosen_answer)) + 1\n            if chosen_prompt_length + chosen_answer_length > self.args.max_length:\n                if chosen_prompt_length >= chosen_answer_length:\n                    chosen_prompt_length -= chosen_prompt_length + chosen_answer_length - self.args.max_length\n                else:\n                    chosen_answer_length -= chosen_prompt_length + chosen_answer_length - self.args.max_length\n            else:\n                chosen_answer_length = self.args.max_length - chosen_prompt_length\n            chosen_encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,\n                                                 max_length=chosen_prompt_length,\n                                                 truncation=\"only_first\",\n                                                 return_tensors=\"pt\",\n                                                 return_token_type_ids=False)\n            chosen_encodings_dict = self.tokenizer.build_inputs_for_generation(chosen_encoded_dict,\n                                                                               targets=chosen_answer,\n                                                                               max_gen_length=chosen_answer_length,\n                                                                               padding=True)\n\n            rejected_answer_length = len(self.tokenizer.tokenize(rejected_answer)) + 1\n            if rejected_prompt_length + rejected_answer_length > self.args.max_length:\n                if rejected_prompt_length >= rejected_answer_length:\n                    rejected_prompt_length -= rejected_prompt_length + rejected_answer_length - self.args.max_length\n                else:\n                    rejected_answer_length -= rejected_prompt_length + rejected_answer_length - self.args.max_length\n            else:\n                rejected_answer_length = self.args.max_length - rejected_prompt_length\n            
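# the rejected answer is encoded with the same prompt-length budget as the chosen one,\n            # so both sequences end up with identical max_length\n            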
rejected_encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,\n                                                   max_length=rejected_prompt_length,\n                                                   truncation=\"only_first\",\n                                                   return_tensors=\"pt\",\n                                                   return_token_type_ids=False)\n            rejected_encodings_dict = self.tokenizer.build_inputs_for_generation(rejected_encoded_dict,\n                                                                                 targets=rejected_answer,\n                                                                                 max_gen_length=rejected_answer_length,\n                                                                                 padding=True)\n            return {\n                \"chosen_input_ids\": chosen_encodings_dict[\"input_ids\"][0],\n                \"chosen_attention_mask\": chosen_encodings_dict[\"attention_mask\"][0],\n                \"chosen_position_ids\": chosen_encodings_dict[\"position_ids\"][0],\n                \"rejected_input_ids\": rejected_encodings_dict[\"input_ids\"][0],\n                \"rejected_attention_mask\": rejected_encodings_dict[\"attention_mask\"][0],\n                \"rejected_position_ids\": rejected_encodings_dict[\"position_ids\"][0],\n                \"labels\": rejected_encodings_dict[\"input_ids\"][0],\n            }\n        else:\n            raise ValueError(f\"Unsupported model name: {self.args.model_name_or_path}\")\n\n    @staticmethod\n    def load_dataset(filename):\n        discard = 0\n        pairs = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for line in tqdm(f, desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                prompt = str(item['prompt'])\n                answers = item['answers']\n                prefix = item.get('prefix', \"\")\n                system = item.get('system', \"\")\n                chosen_answer, rejected_answer = None, None\n                for i in range(len(answers) - 1):\n                    answer_1 = str(answers[i][\"answer\"])\n                    answer_1_score = answers[i][\"score\"]\n                    answer_2 = str(answers[i + 1][\"answer\"])\n                    answer_2_score = answers[i + 1][\"score\"]\n                    if answer_1_score > answer_2_score:\n                        chosen_answer = answer_1\n                    rejected_answer = answer_2\n                    if chosen_answer is not None and rejected_answer is not None \\\n                            and len(prompt) > 0 and len(chosen_answer) > 0 and len(rejected_answer) > 0 \\\n                            and chosen_answer != rejected_answer:\n                        pair = {\n                            \"prompt\": prompt,\n                            \"prefix\": prefix,\n                            \"system\": system,\n                            \"chosen_answer\": chosen_answer,\n                            \"rejected_answer\": rejected_answer\n                        }\n                        pairs.append(pair)\n                    else:\n                        discard += 1\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # pairs: {len(pairs)}, # discarded: {discard}\")\n\n        return pairs\n\n\nclass RLHFDataset(Dataset):\n    def __init__(self, args, filename, tokenizer):\n        self.args = args\n        self.tokenizer = 
tokenizer\n        assert tokenizer.padding_side == \"left\", \"In RLHF training, need to set padding_side to 'left'\"\n\n        self.post_list = self.load_dataset(filename)\n        for k in range(5):\n            print_rank_0(f\"RLHFDataset sample-{k}\n: {self.post_list[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        prefix = data['prefix']\n        system = data['system']\n        if \"pangu\" in self.args.actor_model_path.lower():\n            encoded_dict = self.tokenizer(prompt, self.tokenizer.sep_token + prefix,\n                                          max_length=self.args.max_prompt_length,\n                                          # padding=\"max_length\",\n                                          truncation=\"only_first\", add_special_tokens=False,\n                                          return_tensors=\"pt\", return_token_type_ids=False)\n            return {\n                \"input_ids\": encoded_dict['input_ids'][0],\n                \"attention_mask\": encoded_dict['attention_mask'][0],\n                # \"labels\": encoded_dict['input_ids'],\n            }\n        elif \"chatglm\" in self.args.actor_model_path.lower():\n            prompt = \"\n\n\".join((system, prompt))\n            prompt = f\"[Round {1}]\n\n问：{prompt}\n\n答：\" if \"chatglm2\" in self.args.actor_model_path.lower() else f\"[Round {0}]\n问：{prompt}\n答：\"\n            encoded_dict = self.tokenizer(prompt, max_length=self.args.max_prompt_length,\n                                          return_tensors=\"pt\", truncation=\"only_first\")\n            return {\n                \"input_ids\": encoded_dict['input_ids'][0],\n            }\n        elif \"glm\" in self.args.actor_model_path.lower():\n            # encoded_prompt = self.tokenizer(prompt, prefix + self.tokenizer.mask_token)\n            # prompt_length = len(encoded_prompt['input_ids'])\n            encoded_dict = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,\n                                          max_length=self.args.max_prompt_length,\n                                          # padding=\"max_length\",\n                                          truncation=\"only_first\",\n                                          return_tensors=\"pt\",\n                                          return_token_type_ids=False)\n            encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict,\n                                                                      max_gen_length=self.args.max_gen_length,\n                                                                      padding=True)\n\n            return {\n                \"input_ids\": encoded_dict['input_ids'][0],\n                \"position_ids\": encoded_dict['position_ids'][0],\n                \"generation_attention_mask\": encoded_dict['generation_attention_mask'][0],\n                # \"labels\": encoded_dict['labels'][0],\n            }\n        else:\n            raise ValueError(f\"Unsupported model name: {self.args.actor_model_path}\")\n\n    @staticmethod\n    def load_dataset(filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                data_type = item.get('data_type', \"human_generated\")\n                if data_type != 
\"human_generated\":\n                    continue\n                prompt = str(item['prompt'])\n                prefix = item.get('prefix', \"\")\n                system = item.get('system', \"\")\n\n                if len(prompt) <= 0:\n                    discard += 1\n                    continue\n                datasets.append({\"prompt\": prompt, \"system\": system, \"prefix\": prefix})\n        print_rank_0(\n            f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass PPODataset:\n    def __init__(self, max_size, small_batch_size):\n        self.dataset = []\n        self.max_size = max_size\n        self.small_batch_size = small_batch_size\n\n    def separate(self):\n        small_dataset = []\n        for large_batch in self.dataset:\n            if type(large_batch) == list or type(large_batch) == tuple:\n                large_size = len(large_batch[0])\n            elif type(large_batch) == dict:\n                large_size = len(large_batch[list(large_batch.keys())[0]])\n            else:\n                large_size = len(large_batch)\n            for i in range(0, large_size, self.small_batch_size):\n                if type(large_batch) == list or type(large_batch) == tuple:\n                    small_dataset.append(\n                        [x[i:i + self.small_batch_size] for x in large_batch])\n                elif type(large_batch) == dict:\n                    small_dataset.append({\n                        k: v[i:i + self.small_batch_size] if v is not None else None\n                        for k, v in large_batch.items()\n                    })\n                else:\n                    small_dataset.append(large_batch[i:i + self.small_batch_size])\n        self.free()\n\n        return small_dataset\n\n    def add(self, data):\n        if len(self.dataset) < self.max_size:\n            self.dataset.append(data)\n            if len(self.dataset) == self.max_size:\n                return self.separate()\n            else:\n                return None\n        else:\n            raise ValueError(\n                \"The dataset is full but we did not stop it. 
There is a bug in the code.\"\n            )\n\n    def free(self):\n        self.dataset = []\n\n\nclass DPODataset(Dataset):\n    def __init__(self, args, filename, tokenizer):\n        self.pairs = self.load_dataset(filename)\n        self.args = args\n        self.tokenizer = tokenizer\n\n        for k in range(5):\n            print_rank_0(f\"DPODataset sample-{k}\\n: {self.pairs[k]}\")\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def __getitem__(self, idx):\n        pair = self.pairs[idx]\n\n        index = pair[\"index\"]\n        prompt = pair[\"prompt\"]\n        chosen_answer = pair[\"chosen_answer\"]\n        rejected_answer = pair[\"rejected_answer\"]\n        prefix = pair['prefix']\n        system = pair['system']\n        if \"pangu\" in self.args.model_name_or_path.lower():\n            chosen_encodings_dict = self.tokenizer(prompt, prefix + chosen_answer, max_length=self.args.max_length,\n                                                   truncation=\"longest_first\", padding=\"max_length\",\n                                                   return_tensors=\"pt\",\n                                                   return_token_type_ids=False)\n            rejected_encodings_dict = self.tokenizer(prompt, prefix + rejected_answer, max_length=self.args.max_length,\n                                                     truncation=\"longest_first\", padding=\"max_length\",\n                                                     return_tensors=\"pt\",\n                                                     return_token_type_ids=False)\n            return {\n                \"chosen_input_ids\": chosen_encodings_dict[\"input_ids\"],\n                \"chosen_attention_mask\": chosen_encodings_dict[\"attention_mask\"],\n                \"rejected_input_ids\": rejected_encodings_dict[\"input_ids\"],\n                \"rejected_attention_mask\": rejected_encodings_dict[\"attention_mask\"],\n                \"labels\": rejected_encodings_dict[\"input_ids\"],\n            }\n        elif \"chatglm3\" in self.args.model_name_or_path.lower():\n            chosen_input_ids, chosen_labels, _ = chatglm3_encode(self.tokenizer, prompt, chosen_answer, system,\n                                                                 self.args.max_length)\n            rejected_input_ids, rejected_labels, _ = chatglm3_encode(self.tokenizer, prompt, rejected_answer, system,\n                                                                     self.args.max_length)\n            return {\n                \"index\": torch.tensor(index, dtype=torch.long),\n                \"chosen_input_ids\": torch.tensor(chosen_input_ids, dtype=torch.long),\n                \"rejected_input_ids\": torch.tensor(rejected_input_ids, dtype=torch.long),\n                \"chosen_labels\": torch.tensor(chosen_labels, dtype=torch.long),\n                \"rejected_labels\": torch.tensor(rejected_labels, dtype=torch.long)\n            }\n        elif \"chatglm2\" in self.args.model_name_or_path.lower():\n            chosen_input_ids, chosen_labels, _ = chatglm2_encode(self.tokenizer, prompt, chosen_answer, system,\n                                                                 self.args.max_length)\n            rejected_input_ids, rejected_labels, _ = chatglm2_encode(self.tokenizer, prompt, rejected_answer, system,\n                                                                     self.args.max_length)\n            return {\n                \"index\": torch.tensor(index, dtype=torch.long),\n                
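# chatglm2_encode has already padded/truncated both answers to max_length\n                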
\"chosen_input_ids\": torch.tensor(chosen_input_ids, dtype=torch.long),\n                \"rejected_input_ids\": torch.tensor(rejected_input_ids, dtype=torch.long),\n                \"chosen_labels\": torch.tensor(chosen_labels, dtype=torch.long),\n                \"rejected_labels\": torch.tensor(rejected_labels, dtype=torch.long)\n            }\n        elif \"chatglm\" in self.args.model_name_or_path.lower():\n            prompt = f\"[Round {0}]\\n问：{prompt}\\n答：\"\n            chosen_encodings_dict = self.tokenizer(prompt, chosen_answer, max_length=self.args.max_length,\n                                                   truncation=\"longest_first\", padding=\"max_length\",\n                                                   return_tensors=\"pt\")\n            rejected_encodings_dict = self.tokenizer(prompt, rejected_answer, max_length=self.args.max_length,\n                                                     truncation=\"longest_first\", padding=\"max_length\",\n                                                     return_tensors=\"pt\")\n            return {\n                \"chosen_input_ids\": chosen_encodings_dict[\"input_ids\"][0],\n                \"rejected_input_ids\": rejected_encodings_dict[\"input_ids\"][0],\n                \"labels\": rejected_encodings_dict[\"input_ids\"][0],\n            }\n        else:\n            raise ValueError(f\"Unsupported model name: {self.args.model_name_or_path}\")\n\n    @staticmethod\n    def load_dataset(filename):\n        discard = 0\n        index = 1\n        pairs = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for line in tqdm(f, desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                prompt = str(item['prompt'])\n                answers = item['answers']\n                prefix = item.get('prefix', \"\")\n                system = item.get('system', \"\")\n                chosen_answer, rejected_answer = None, None\n                for i in range(len(answers) - 1):\n                    answer_1 = str(answers[i][\"answer\"])\n                    answer_1_score = answers[i][\"score\"]\n                    answer_2 = str(answers[i + 1][\"answer\"])\n                    answer_2_score = answers[i + 1][\"score\"]\n                    if answer_1_score > answer_2_score:\n                        chosen_answer = answer_1\n                    rejected_answer = answer_2\n                    if chosen_answer is not None and rejected_answer is not None \\\n                            and len(prompt) > 0 and len(chosen_answer) > 0 and len(rejected_answer) > 0 \\\n                            and chosen_answer != rejected_answer:\n                        pair = {\n                            \"index\": index,\n                            \"prompt\": prompt,\n                            \"prefix\": prefix,\n                            \"system\": system,\n                            \"chosen_answer\": chosen_answer,\n                            \"rejected_answer\": rejected_answer\n                        }\n                        index += 1\n                        pairs.append(pair)\n                    else:\n                        discard += 1\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # pairs: {len(pairs)}, # discarded: {discard}\")\n\n        return pairs\n\n\nclass OCNLIDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n 
       self.label_dict = {'entailment': 'Yes', 'neutral': 'Maybe', 'contradiction': 'No'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"OCNLIDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                s1 = item['sentence1']\n                s2 = item['sentence2']\n                label = item['label']\n                # skip samples whose annotations conflict\n                if label == \"-\":\n                    discard += 1\n                    continue\n                for l in self.label_dict.values():\n                    prompt = f'{s1}?{l}，{s2}'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": self.label_dict[label]})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass CMNLIDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'entailment': 'Yes', 'neutral': 'Maybe', 'contradiction': 'No'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"CMNLIDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                s1 = item['sentence1']\n                s2 = item['sentence2']\n                label = item['label']\n                # skip samples whose annotations conflict\n                if label == \"-\":\n                    discard += 1\n                    continue\n                for l in self.label_dict.values():\n                    prompt = f'{s1}?{l}，{s2}'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": self.label_dict[label]})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass CHIDDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n\n        self.idiom_dict = self.load_idiom_dict()\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"CHIDDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n        candidates = data['candidates']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label,\n            \"candidates\": candidates\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                candidates = item['candidates']\n                contents = item['content']\n                for content in contents:\n                    for idiom in re.findall(r\"#idiom\d+#\", content):\n                        label = candidates[self.idiom_dict[idiom]]\n                        for candidate in candidates:\n                            prompt = content.replace(idiom, candidate)\n                            if len(prompt) <= 0:\n                                continue\n                            datasets.append({\"prompt\": prompt, \"label\": label, \"candidates\": candidates})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n    def load_idiom_dict(self):\n        idiom_dict = json.load(open(os.path.join(self.args.data_dir, \"dev_answer.json\"), \"r\", encoding=\"utf-8\"))\n        idiom_dict.update(json.load(open(os.path.join(self.args.data_dir, \"train_answer.json\"), \"r\", encoding=\"utf-8\")))\n\n        print_rank_0(f\"Finished loading idiom dict\")\n\n        return idiom_dict\n\n\nclass CMRCDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"CMRCDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        data = json.load(open(filename, \"r\", encoding=\"utf-8\"))\n        for paragraphs in data['data']:\n            for paragraph in paragraphs['paragraphs']:\n                context = paragraph['context']\n                for qs in paragraph['qas']:\n                    question = qs['question']\n                    # deduplicate answers while preserving their order\n                    answers = []\n                    for answer in qs['answers']:\n                        if answer not in answers:\n                            answers.append(answer)\n                    prompt_template = \"阅读文章：{context}\n问：{question}\n答：\"\n                    prompt = prompt_template.format(context=context, question=question)\n                    if len(prompt) <= 0:\n                        continue\n                    # if len(prompt) > self.args.max_length:\n                    #     idx = len(prompt) - self.args.max_length\n                    #     prompt = prompt_template.format(context=context[:-idx], question=question)\n                    datasets.append({\"prompt\": prompt, \"label\": answers})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass CLUEWSCDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'true': '1', 'false': '0'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"CLUEWSCDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label,\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                text = item['text']\n                span2_index = item['target']['span2_index']\n                span2_text = item['target']['span2_text']\n                span1_text = item['target']['span1_text']\n                label = self.label_dict[item['label']]\n                prompt = text[:span2_index] + span1_text + text[span2_index + len(span2_text):]\n                if len(prompt) <= 0:\n                    continue\n                datasets.append({\"prompt\": prompt, \"label\": label})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass C3Dataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"C3Dataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n        candidates = data['candidates']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label,\n            \"candidates\": candidates\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n\n        data = json.load(open(filename, \"r\", encoding=\"utf-8\"))\n        for i, d in enumerate(data):\n            context = \"\".join(d[0])\n            for qs in d[1]:\n                question = qs['question']\n                choices = qs['choice']\n                # pad the candidate list to 4 entries so every sample exposes the same number of candidates\n                choices_padded = [choices[i] if i < len(choices) else f\"test{i}\" for i in range(4)]\n                answer = qs['answer']\n                for choice in choices:\n                    prompt = f\"问: {question}\n答:{choice}\n该答案来自对话: {context}\"\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": answer, \"candidates\": choices_padded})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass AFQMCDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'0': '不同', '1': '相同'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"AFQMCDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                s1 = item['sentence1']\n                s2 = item['sentence2']\n                label = self.label_dict[item['label']]\n                for l in self.label_dict.values():\n                    prompt = f'下面两个句子语义{l}:{s1}。{s2}'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": label})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass CSLDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'0': '不是', '1': '是'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(train_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"CSLDataset sample-{k}\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                example_tokens = self.tokenizer.tokenize(example_prompt + \"\n\")\n                if len(example_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(example_tokens)\n            prompts.append(prompt)\n            prompt = \"\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                abstract = item['abst']\n                keyword = \"+\".join(item['keyword'])\n                label = self.label_dict[item['label']]\n                for l in self.label_dict.values():\n                    prompt = f'摘要:{abstract}，关键词:{keyword}{l}真实关键词'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": label})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # samples: {len(datasets)}, # discarded: {discard}\")\n\n        return 
datasets\n\n\nclass IFLYTEKDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'0': '打车', '1': '地图导航', '2': '免费WIFI', '3': '租车', '4': '同城服务', '5': '快递物流',\n                           '6': '婚庆', '7': '家政', '8': '公共交通', '9': '政务', '10': '社区服务', '11': '薅羊毛',\n                           '12': '魔幻', '13': '仙侠', '14': '卡牌', '15': '飞行空战', '16': '射击游戏',\n                           '17': '休闲益智', '18': '动作类', '19': '体育竞技', '20': '棋牌中心', '21': '经营养成',\n                           '22': '策略', '23': 'MOBA', '24': '辅助工具', '25': '约会社交', '26': '即时通讯',\n                           '27': '工作社交', '28': '论坛圈子', '29': '婚恋社交', '30': '情侣社交', '31': '社交工具',\n                           '32': '生活社交', '33': '微博博客', '34': '新闻', '35': '漫画', '36': '小说', '37': '技术',\n                           '38': '教辅', '39': '问答交流', '40': '搞笑', '41': '杂志', '42': '百科', '43': '影视娱乐',\n                           '44': '求职', '45': '兼职', '46': '视频', '47': '短视频', '48': '音乐', '49': '直播',\n                           '50': '电台', '51': 'K歌', '52': '成人', '53': '中小学', '54': '职考', '55': '公务员',\n                           '56': '英语', '57': '视频教育', '58': '高等教育', '59': '成人教育', '60': '艺术',\n                           '61': '语言(非英语)', '62': '旅游资讯', '63': '综合预定', '64': '民航', '65': '铁路',\n                           '66': '酒店', '67': '行程管理', '68': '民宿短租', '69': '出国', '70': '工具',\n                           '71': '亲子儿童', '72': '母婴', '73': '驾校', '74': '违章', '75': '汽车咨询',\n                           '76': '汽车交易', '77': '日常养车', '78': '行车辅助', '79': '租房', '80': '买房',\n                           '81': '装修家居', '82': '电子产品', '83': '问诊挂号', '84': '养生保健', '85': '医疗服务',\n                           '86': '减肥瘦身', '87': '美妆美业', '88': '菜谱', '89': '餐饮店', '90': '体育咨讯',\n                           '91': '运动健身', '92': '支付', '93': '保险', '94': '股票', '95': '借贷', '96': '理财',\n                           '97': '彩票', '98': '记账', '99': '银行', '100': '美颜', '101': '影像剪辑',\n                           '102': '摄影修图', '103': '相机', '104': '绘画', '105': '二手', '106': '电商', '107': '团购',\n                           '108': '外卖', '109': '电影票务', '110': '社区超市', '111': '购物咨询', '112': '笔记',\n                           '113': '办公', '114': '日程管理', '115': '女性', '116': '经营', '117': '收款', '118': '其他'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(eval_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"IFLYTEKDataset sample-{k}\\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n        candidates = data['candidates']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n                exmample_tokens = self.tokenizer.tokenize(example_prompt + \"\\n\")\n                if len(exmample_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n      
              prompt_tokens.extend(exmample_tokens)\n            prompts.append(prompt)\n            prompt = \"\\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label,\n            \"candidates\": candidates\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                content = item['sentence']\n                label = item['label_des']\n                # randomly sample 3 categories as negative sample\n                labels = set(self.label_dict.values())\n                labels.remove(label)\n                candidates = [label] + random.sample(labels, 3)\n                for l in candidates:\n                    prompt = f'这是关于{l}的应用程序:{content}'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": label, \"candidates\": candidates})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass TNEWSDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.label_dict = {'100': '故事',\n                           '101': '文化',\n                           '102': '娱乐',\n                           '103': '体育',\n                           '104': '财经',\n                           '106': '房产',\n                           '107': '汽车',\n                           '108': '教育',\n                           '109': '科技',\n                           '110': '军事',\n                           '112': '旅游',\n                           '113': '世界',\n                           '114': '股票',\n                           '115': '农业',\n                           '116': '游戏'}\n\n        dataset = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.labelled_list = self.load_dataset(eval_filename)\n        self.post_list = dataset\n\n        for k in range(5):\n            print_rank_0(f\"TNEWSDataset sample-{k}\\n: {dataset[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        prompt = data['prompt']\n        label = data['label']\n        candidates = data['candidates']\n\n        # Few-Shot example construction\n        if hasattr(self, \"labelled_list\"):\n            examples = random.sample(self.labelled_list, min(len(self.labelled_list), self.args.max_few_shot))\n            prompts = []\n            prompt_tokens = self.tokenizer.tokenize(prompt)\n            for example in examples:\n                example_prompt = example['prompt']\n             
   exmample_tokens = self.tokenizer.tokenize(example_prompt + \"\\n\")\n                if len(exmample_tokens) + len(prompt_tokens) + 2 > self.args.max_length:\n                    break\n                else:\n                    prompts.append(example_prompt)\n                    prompt_tokens.extend(exmample_tokens)\n            prompts.append(prompt)\n            prompt = \"\\n\".join(prompts)\n\n        encoded_dict = self.tokenizer(prompt, max_length=self.args.max_length,\n                                      padding=\"max_length\", truncation=\"longest_first\", return_tensors=\"pt\")\n        # label_dict = self.tokenizer(label, max_length=self.args.max_length, add_special_tokens=False,\n        #                             return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict[\"attention_mask\"],\n            \"labels\": encoded_dict[\"input_ids\"],\n            \"label_str\": label,\n            \"candidates\": candidates\n        }\n\n    def load_dataset(self, filename):\n        discard = 0\n        datasets = []\n        with open(filename, \"r\", encoding=\"utf-8\") as f:\n            for i, line in tqdm(enumerate(f), desc=f\"Loading {os.path.basename(filename)}\"):\n                item = json.loads(line)\n                content = item['sentence']\n                label = self.label_dict[item['label']]\n                # randomly sample 3 categories as negative sample\n                labels = set(self.label_dict.values())\n                labels.remove(label)\n                candidates = [label] + random.sample(labels, 3)\n                for l in candidates:\n                    prompt = f'这是关于{l}的文章:{content}'\n                    if len(prompt) <= 0:\n                        continue\n                    datasets.append({\"prompt\": prompt, \"label\": label, \"candidates\": candidates})\n\n        print_rank_0(f\"Finished loading {os.path.basename(filename)}, # discarded: {discard}\")\n\n        return datasets\n\n\nclass CEvalDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.model_name_or_path = args.model_name_or_path if hasattr(args,\n                                                                     \"model_name_or_path\") else args.actor_model_path\n        self.subject_mapping = json.load(open(os.path.join(RESOURCE_PATH, \"eval\", \"ceval\", \"subject_mapping.json\")))\n        self.max_length = args.max_length - args.max_length_generation\n        self.choices = [\"A\", \"B\", \"C\", \"D\"]\n\n        self.post_list = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.dev_list = self.load_dataset(train_filename, \"dict\")\n\n        for k in range(5):\n            print_rank_0(f\"CEvalDataset sample-{k}\\n: {self.post_list[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def format_example(self, line, include_answer=True, cot=False):\n        example = line['question']\n        for choice in self.choices:\n            example += f'\\n{choice}. 
{line[f\"{choice}\"]}'\n        example += '\\n答案：'\n        if \"chatglm\" in self.model_name_or_path.lower() or \"qwen\" in self.model_name_or_path.lower():\n            if include_answer:\n                if cot:\n                    ans = \"让我们一步一步思考，\\n\" + line[\"explanation\"] + f\"\\n所以答案是{line['answer']}。\"\n                else:\n                    ans = line[\"answer\"]\n                m = (example, ans)\n                return m\n            return example\n        else:\n            # example = line['question']\n            # for choice in self.choices:\n            #     example += f'\\n{choice}. {line[f\"{choice}\"]}'\n            if include_answer:\n                if cot:\n                    example += \"让我们一步一步思考，\\n\" + line[\"explanation\"] + f\"\\n所以答案是{line['answer']}。\"\n                else:\n                    example += line[\"answer\"]\n            else:\n                if cot:\n                    example += \"让我们一步一步思考，\\n1.\"\n            return example\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        subject_name = data['subject_name']\n        question = self.format_example(data, include_answer=False, cot=self.args.cot)\n        prefix = f\"以下是中国关于{subject_name}考试的单项选择题，请选出其中的正确答案。\"\n\n        history = []\n        if \"chatglm\" in self.model_name_or_path.lower():\n            sep = \"\\n\\n\" if \"chatglm2\" in self.model_name_or_path.lower() else \"\\n\"\n            offset = 1 if \"chatglm2\" in self.model_name_or_path.lower() else 0\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n                history.append(prefix)\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    prompt, answer = self.format_example(dev_list[i], include_answer=True, cot=self.args.cot)\n                    prompt = f\"[Round {i + offset}]{sep}问：{prompt}{sep}答：{answer}\"\n                    history.append(prompt)\n            # Concat few-shot/zero-shot examples with question.\n            # If length of full prompt exceeds max_length, remove examples until the length is smaller than max_length\n            question = f\"[Round {len(history) + offset}]{sep}问：{question}{sep}答：\"\n            while True:\n                full_prompt = sep.join(history + [question])\n                input_ids = self.tokenizer.encode(full_prompt)\n                if len(input_ids) <= self.max_length:\n                    break\n                elif len(history) <= 1:\n                    full_prompt = question\n                    break\n                else:\n                    history.pop(-1)\n            encoded_dict = self.tokenizer(full_prompt, max_length=self.max_length, return_tensors=\"pt\",\n                                          truncation=\"longest_first\")\n        elif \"qwen\" in self.model_name_or_path.lower():\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    history.append(self.format_example(dev_list[i], include_answer=True, cot=self.args.cot))\n            full_prompt, input_ids = qwen_make_context(self.tokenizer, question, history, system=prefix,\n                                                       max_window_size=self.max_length)\n            encoded_dict = 
{\"input_ids\": torch.tensor(input_ids, dtype=torch.int64)}\n        else:\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n                history.append(prefix)\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    history.append(self.format_example(dev_list[i], include_answer=True, cot=self.args.cot))\n            # Concat few-shot/zero-shot examples with question.\n            # If length of full prompt exceeds max_length, remove examples until the length is smaller than max_length\n            while True:\n                full_prompt = \"\\n\\n\".join(history + [question])\n                input_ids = self.tokenizer.encode(full_prompt)\n                if len(input_ids) <= self.max_length:\n                    break\n                elif len(history) <= 1:\n                    full_prompt = question\n                    break\n                else:\n                    history.pop(-1)\n            encoded_dict = self.tokenizer(full_prompt, max_length=self.max_length, return_tensors=\"pt\",\n                                          truncation=\"longest_first\")\n\n        logger.debug(f\"number of shots: {len(history) - 1}, full prompt: {full_prompt}\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict.get(\"attention_mask\", None),\n            \"number_of_shots\": max(len(history) - 1, 0),\n            \"id\": data['id'],\n            \"subject_name_key\": data['subject_name_key'],\n            \"answer\": data.get('answer', None)\n        }\n\n    def load_dataset(self, filename, return_format=\"list\"):\n        datasets = list() if return_format == \"list\" else dict()\n        dt = os.path.basename(filename)\n\n        for subject_name_key, subject in self.subject_mapping.items():\n            subject_name = subject[1]\n            if isinstance(datasets, dict):\n                datasets[subject_name] = list()\n            dev_file_path = os.path.join(filename, f'{subject_name_key}_{dt}.csv')\n            dev_df = pd.read_csv(dev_file_path)\n            for i, val in dev_df.iterrows():\n                d = val.to_dict()\n                if isinstance(datasets, dict):\n                    datasets[subject_name].append(d)\n                else:\n                    d['subject_name'] = subject_name\n                    d['subject_name_key'] = subject_name_key\n                    datasets.append(d)\n\n        print_rank_0(f\"Finished loading {dt} dataset\")\n\n        return datasets\n\n\nclass MMLUDataset(Dataset):\n    def __init__(self, args, eval_filename, tokenizer, train_filename=None):\n        self.tokenizer = tokenizer\n        self.args = args\n        self.model_name_or_path = args.model_name_or_path if hasattr(args,\n                                                                     \"model_name_or_path\") else args.actor_model_path\n        self.subject_mapping = json.load(open(os.path.join(RESOURCE_PATH, \"eval\", \"mmlu\", \"subject_mapping.json\")))\n        self.choices = [\"A\", \"B\", \"C\", \"D\"]\n        self.max_length = args.max_length - args.max_length_generation\n\n        self.post_list = self.load_dataset(eval_filename)\n        if train_filename is not None:\n            self.dev_list = self.load_dataset(train_filename, \"dict\")\n\n        for k in range(5):\n            print_rank_0(f\"MMLUDataset sample-{k}\\n: 
{self.post_list[k]}\")\n\n    def __len__(self):\n        return len(self.post_list)\n\n    def format_example(self, line, include_answer=True):\n        example = line['question']\n        for choice in self.choices:\n            example += f'\\n{choice}. {line[f\"{choice}\"]}'\n        example += '\\nAnswer：'\n        if \"chatglm\" in self.model_name_or_path.lower() or \"qwen\" in self.model_name_or_path.lower():\n            if include_answer:\n                ans = line[\"answer\"]\n                m = (example, ans)\n                return m\n            return example\n        else:\n            # example = line['question']\n            # for choice in self.choices:\n            #     example += f'\\n{choice}. {line[f\"{choice}\"]}'\n            if include_answer:\n                example += line[\"answer\"]\n            return example\n\n    def __getitem__(self, idx):\n        data = self.post_list[idx]\n        subject_name = data['subject_name']\n        question = self.format_example(data, include_answer=False)\n        prefix = f\"The following are multiple choice questions (with answers) about {subject_name}.\"\n\n        history = []\n        if \"chatglm\" in self.model_name_or_path.lower():\n            sep = \"\\n\\n\" if \"chatglm2\" in self.model_name_or_path.lower() else \"\\n\"\n            offset = 1 if \"chatglm2\" in self.model_name_or_path.lower() else 0\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n                history.append(prefix)\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    prompt, answer = self.format_example(dev_list[i], include_answer=True)\n                    prompt = f\"[Round {i + offset}]{sep}问：{prompt}{sep}答：{answer}\"\n                    history.append(prompt)\n            # Concat few-shot/zero-shot examples with question.\n            # If length of full prompt exceeds max_length, remove examples until the length is smaller than max_length\n            question = f\"[Round {len(history) + offset}]{sep}问：{question}{sep}答：\"\n            while True:\n                full_prompt = sep.join(history + [question])\n                input_ids = self.tokenizer.encode(full_prompt)\n                if len(input_ids) <= self.max_length:\n                    break\n                elif len(history) <= 1:\n                    full_prompt = question\n                    break\n                else:\n                    history.pop(-1)\n\n            encoded_dict = self.tokenizer(full_prompt, max_length=self.max_length, return_tensors=\"pt\",\n                                          truncation=\"longest_first\")\n        elif \"qwen\" in self.model_name_or_path.lower():\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    history.append(self.format_example(dev_list[i], include_answer=True))\n            full_prompt, input_ids = qwen_make_context(self.tokenizer, question, history, system=prefix,\n                                                       max_window_size=self.max_length)\n            encoded_dict = {\"input_ids\": torch.tensor(input_ids, dtype=torch.int64)}\n        else:\n            # Few-Shot example construction\n            if hasattr(self, \"dev_list\"):\n  
              history.append(prefix)\n                k = self.args.max_few_shot\n                dev_list = self.dev_list[subject_name]\n                for i in range(min(k, len(dev_list))):\n                    history.append(self.format_example(dev_list[i], include_answer=True))\n            # Concat few-shot/zero-shot examples with question.\n            # If length of full prompt exceeds max_length, remove examples until the length is smaller than max_length\n            while True:\n                full_prompt = \"\\n\\n\".join(history + [question])\n                input_ids = self.tokenizer.encode(full_prompt)\n                if len(input_ids) <= self.max_length:\n                    break\n                elif len(history) <= 1:\n                    full_prompt = question\n                    break\n                else:\n                    history.pop(-1)\n\n            encoded_dict = self.tokenizer(full_prompt, max_length=self.max_length, return_tensors=\"pt\",\n                                          truncation=\"longest_first\")\n\n        return {\n            \"input_ids\": encoded_dict[\"input_ids\"],\n            \"attention_mask\": encoded_dict.get(\"attention_mask\", None),\n            \"number_of_shots\": max(len(history) - 1, 0),\n            \"subject_name_key\": data['subject_name_key'],\n            \"answer\": data.get('answer', None)\n        }\n\n    def load_dataset(self, filename, return_format=\"list\"):\n        datasets = list() if return_format == \"list\" else dict()\n        dt = os.path.basename(filename)\n\n        for subject_name_key, subject in self.subject_mapping.items():\n            subject_name = subject[0]\n            if isinstance(datasets, dict):\n                datasets[subject_name] = list()\n            dev_file_path = os.path.join(filename, f'{subject_name_key}_{dt}.csv')\n            dev_df = pd.read_csv(dev_file_path, names=[\"question\", \"A\", \"B\", \"C\", \"D\", \"answer\"])\n            for i, val in dev_df.iterrows():\n                d = val.to_dict()\n                if isinstance(datasets, dict):\n                    datasets[subject_name].append(d)\n                else:\n                    d['subject_name'] = subject_name\n                    d['subject_name_key'] = subject_name_key\n                    datasets.append(d)\n\n        print_rank_0(f\"Finished loading {dt} dataset\")\n\n        return datasets\n"
  },
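  {
    "path": "examples/few_shot_packing_demo.py",
    "content": "\n# Illustrative sketch -- NOT part of the original repo. It mirrors the few-shot\n# prompt packing used by the evaluation Dataset classes above: demonstrations are\n# appended until the token budget len(example_tokens) + len(prompt_tokens) + 2 >\n# max_length is exhausted. The whitespace tokenizer below is only a stand-in for\n# a real HuggingFace tokenizer.\nimport random\n\n\ndef pack_few_shot(prompt, labelled_list, max_length, max_few_shot, tokenize=str.split):\n    # sample candidate demonstrations, as in Dataset.__getitem__\n    examples = random.sample(labelled_list, min(len(labelled_list), max_few_shot))\n    prompts = []\n    prompt_tokens = tokenize(prompt)\n    for example in examples:\n        example_tokens = tokenize(example + \"\\n\")\n        # reserve 2 positions for special tokens, as in the original loop\n        if len(example_tokens) + len(prompt_tokens) + 2 > max_length:\n            break\n        prompts.append(example)\n        prompt_tokens.extend(example_tokens)\n    prompts.append(prompt)\n    return \"\\n\".join(prompts)\n\n\nif __name__ == \"__main__\":\n    shots = [\"问：1+1 答：2\", \"问：2+2 答：4\", \"问：3+3 答：6\"]\n    print(pack_few_shot(\"问：4+4 答：\", shots, max_length=16, max_few_shot=2))\n"
  },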
  {
    "path": "src/data/data_types.py",
    "content": "\nfrom dataclasses import dataclass\nfrom typing import Iterable\n\nfrom torchtyping import TensorType\n\n\n@dataclass\nclass PromptElement:\n    \"\"\"\n    Dataclass for a single prompt, containing its string and tokenized form.\n\n    :param text: The prompt text.\n    :type text: str\n\n    :param tokens: The prompt tokens. Should be a long tensor\n    :type tokens: torch.Tensor\n    \"\"\"\n\n    text: str\n    tokens: TensorType[\"num_tokens\"]\n\n\n@dataclass\nclass PromptBatch:\n    \"\"\"\n    Batched PromptElement\n\n    :param text: An iterable of prompt texts.\n    :type text: Iterable[str]\n\n    :param tokens: A long tensor batch of prompt tokens.\n    :type tokens: torch.Tensor\n    \"\"\"\n\n    text: Iterable[str]\n    tokens: TensorType[\"batch_size\", \"num_tokens\"]\n\n\n@dataclass\nclass AccelerateRLElement:\n    \"\"\"\n    Dataclass for RL elements, containing output tokens and rewards for each token.\n\n    :param output_tokens: The output tokens. Should be a long tensor\n    :type output_tokens: torch.Tensor\n\n    :param rewards: The rewards for each token. Should be a float tensor of same size as tokens.\n    :type rewards: torch.Tensor\n    \"\"\"\n\n    output_tokens: TensorType[\"output_size\"]\n    rewards: TensorType[\"output_size\"]\n\n\n@dataclass\nclass AccelerateRLBatchElement:\n    \"\"\"\n    Batched accelerate RL element\n\n    :param output_tokens: Batches of long tensors of output tokens.\n    :type output_tokens: torch.Tensor\n\n    :param rewards: Batches of float tensors of rewards for each output token.\n    :type rewards: torch.Tensor\n    \"\"\"\n\n    output_tokens: TensorType[\"batch_size\", \"output_size\"]\n    rewards: TensorType[\"batch_size\", \"output_size\"]\n\n\n@dataclass\nclass PPORLElement:\n    \"\"\"\n    :param query_tensor: The query tensor i.e. the prompt tokens.\n                         Should be a long tensor.\n    :type query_tensor: torch.Tensor\n\n    :param response_tensor: The response tensor i.e. the output tokens.\n                            Should be a long tensor.\n    :type response_tensor: torch.Tensor\n\n    :param attention_mask: The attention mask for the query tokens.\n                           Should be a long tensor of same size as query_tensor.\n    :type attention_mask: torch.Tensor\n\n    :param position_ids: The position ids for the query tokens.\n                         Should be a long tensor of same size as query_tensor.\n    :type position_ids: torch.Tensor\n\n    :param logprobs: The log probabilities over the response tokens generated\n                    by the policy network (i.e. the autoregressive model).\n                    Should be a float tensor of same size as tokens.\n    :type logprobs: torch.Tensor\n\n    :param values: The values for each token generated from the value network or value head.\n                    Should be a float tensor of same size as tokens.\n    :type values: torch.Tensor\n\n    :param rewards: The rewards for each token outputted in response.\n                    Should be a float tensor of same size as tokens.\n    :type rewards: torch.Tensor\n    \"\"\"\n\n    query_tensor: TensorType[\"query_size\"]\n    response_tensor: TensorType[\"response_size\"]\n    attention_mask: TensorType[\"query_size\"]\n    position_ids: TensorType[\"query_size\"]\n    logprobs: TensorType[\"response_size\"]\n    values: TensorType[\"response_size\"]\n    rewards: TensorType[\"response_size\"]\n\n\n@dataclass\nclass PPORLBatch:\n    \"\"\"\n    A batched version of the PPORLElement. See PPORLElement for more details on individual fields.\n\n    :param query_tensors: A batch of query tensors. Should be a long tensor.\n    :type query_tensors: torch.Tensor\n\n    :param response_tensors: A batch of response tensors. Should be a long tensor.\n    :type response_tensors: torch.Tensor\n\n    :param attention_mask: A batch of attention masks for the query tokens\n    :type attention_mask: torch.Tensor\n\n    :param position_ids: A batch of position ids for the query tokens\n    :type position_ids: torch.Tensor\n\n    :param logprobs: A batch of log probabilities from policy\n    :type logprobs: torch.Tensor\n\n    :param values: A batch of values from value network\n    :type values: torch.Tensor\n\n    :param rewards: A batch of rewards\n    :type rewards: torch.Tensor\n    \"\"\"\n\n    query_tensors: TensorType[\"batch_size\", \"query_size\"]\n    response_tensors: TensorType[\"batch_size\", \"response_size\"]\n    attention_mask: TensorType[\"batch_size\", \"query_size\"]\n    position_ids: TensorType[\"batch_size\", \"query_size\"]\n    logprobs: TensorType[\"batch_size\", \"response_size\"]\n    values: TensorType[\"batch_size\", \"response_size\"]\n    rewards: TensorType[\"batch_size\", \"response_size\"]"
  },
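  {
    "path": "examples/data_types_demo.py",
    "content": "\n# Illustrative sketch -- NOT part of the original repo. It shows the tensor\n# shapes expected by src/data/data_types.py by building two PPORLElement\n# rollouts and collating them into a PPORLBatch the same way\n# PPORolloutStorage.create_loader in src/data/pipeline.py does\n# (stack the fixed-size query-side tensors, right-pad the response side).\nimport torch\n\nfrom torch.nn.utils.rnn import pad_sequence\n\nfrom src.data.data_types import PPORLBatch, PPORLElement\n\n\ndef make_element(query_len, response_len):\n    # one rollout: fixed-size query, variable-size response\n    return PPORLElement(\n        query_tensor=torch.randint(1, 100, (query_len,)),\n        response_tensor=torch.randint(1, 100, (response_len,)),\n        attention_mask=torch.ones(query_len, dtype=torch.long),\n        position_ids=torch.arange(query_len),\n        logprobs=torch.randn(response_len),\n        values=torch.randn(response_len),\n        rewards=torch.randn(response_len),\n    )\n\n\nif __name__ == \"__main__\":\n    elems = [make_element(8, 3), make_element(8, 5)]\n    batch = PPORLBatch(\n        torch.stack([e.query_tensor for e in elems]),\n        pad_sequence([e.response_tensor for e in elems], padding_value=0, batch_first=True),\n        torch.stack([e.attention_mask for e in elems]),\n        [e.position_ids for e in elems],\n        pad_sequence([e.logprobs for e in elems], padding_value=0.0, batch_first=True),\n        pad_sequence([e.values for e in elems], padding_value=0.0, batch_first=True),\n        pad_sequence([e.rewards for e in elems], padding_value=0.0, batch_first=True),\n    )\n    print(batch.query_tensors.shape, batch.response_tensors.shape)  # (2, 8) (2, 5)\n"
  },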
  {
    "path": "src/data/pipeline.py",
    "content": "\nimport os\nimport sys\nimport time\nimport json\nimport torch\n\nfrom abc import abstractmethod\nfrom typing import Optional, Any, Callable, Dict, Iterable, List, Union\nfrom dataclasses import dataclass\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.nn.utils.rnn import pad_sequence\nfrom torchtyping import TensorType\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerBase\nfrom transformers.tokenization_utils_base import PaddingStrategy\n\nfrom src.data.data_types import PPORLBatch, PPORLElement\nfrom src.utils.config import TRLConfig\n\n\n# specifies a dictionary of data pipelines\n_DATAPIPELINE: Dict[str, Any] = {}  # registry\n\n@dataclass\nclass GeneralElement:\n    \"\"\"\n    General element outputted by a data pipeline\n    \"\"\"\n\n    pass\n\n\n@dataclass\nclass RLElement:\n    \"\"\"\n    Batch element for RL model\n    \"\"\"\n\n    state: Iterable[str] = None  # Context/prompts\n    action: TensorType[\"N\"] = None  # Tokens generated by model given prompts\n    reward: float = None  # Reward obtained for that generation\n\n\n@dataclass\nclass BatchElement:\n    \"\"\"\n    General batch element for any transformer to use in its forward pass\n    \"\"\"\n\n    tokens: TensorType[\"BATCH\", \"SEQ_LEN\"]\n    masks: TensorType[\"BATCH\", \"SEQ_LEN\"]\n\n\n@dataclass\nclass GLMDataCollator:\n\n    tokenizer: PreTrainedTokenizerBase\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    return_tensors: str = \"pt\"\n\n    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:\n        # batch = self.tokenizer.pad(\n        #     features,\n        #     padding=self.padding,\n        #     max_length=self.max_length,\n        #     pad_to_multiple_of=self.pad_to_multiple_of,\n        #     return_tensors=self.return_tensors,\n        # )\n        max_length = max(map(lambda x: x['input_ids'].shape[0], features))\n        input_ids_list = []\n        attention_mask_list = []\n        position_ids_list = []\n        labels_list = []\n        for feature in features:\n            input_ids = feature['input_ids']\n            seq_length = input_ids.shape[0]\n            # padding for GLM generation: cls_token_id + prompt_tokens + mask_token_id + [eos_token_id]*N + sop_token_id\n            input_ids = torch.cat((input_ids[:-1],\n                                   torch.tensor([self.tokenizer.pad_token_id]*(max_length-seq_length), dtype=input_ids.dtype),\n                                   input_ids[-1:]\n                                   ), dim=0)\n            input_ids_list.append(input_ids)\n            attention_mask_list.append(feature['generation_attention_mask'])\n            position_ids_list.append(feature['position_ids'])\n            if \"labels\" in feature:\n                labels_list.append(feature['labels'])\n\n        batch = {\n            \"input_ids\": torch.stack(input_ids_list, dim=0),\n            \"generation_attention_mask\": torch.stack(attention_mask_list, dim=0),\n            \"position_ids\": torch.stack(position_ids_list, dim=0)\n        }\n\n        if len(labels_list) > 0:\n            batch['labels'] = torch.stack(labels_list, dim=0)\n\n        return batch\n\n\ndef register_datapipeline(name):\n    \"\"\"Decorator used to register a data pipeline\n    Args:\n        name: Name under which to register the pipeline\n    \"\"\"\n\n    def register_class(cls, name):\n        _DATAPIPELINE[name] = cls\n        setattr(sys.modules[__name__], name, cls)\n        return cls\n\n    if isinstance(name, str):\n        name = name.lower()\n        return lambda c: register_class(c, name)\n\n    cls = name\n    name = cls.__name__\n    register_class(cls, name.lower())\n\n    return cls\n\n\n@register_datapipeline\nclass BasePipeline(Dataset):\n    def __init__(self, path: str = \"dataset\"):\n        super().__init__()\n\n    @abstractmethod\n    def __getitem__(self, index: int) -> GeneralElement:\n        pass\n\n    @abstractmethod\n    def __len__(self) -> int:\n        pass\n\n    @abstractmethod\n    def create_loader(\n            self,\n            batch_size: int,\n            shuffle: bool,\n            prep_fn: Callable = None,\n            num_workers: int = 0,\n    ) -> DataLoader:\n        \"\"\"\n        Create a dataloader for the pipeline\n\n        :param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation.\n        \"\"\"\n        pass\n\n\nclass BaseRolloutStore(Dataset):\n    def __init__(self, capacity=-1):\n        self.history: Iterable[Any] = None\n        self.capacity = capacity\n\n    @abstractmethod\n    def push(self, exps: Iterable[Any]):\n        \"\"\"\n        Push experiences to rollout storage\n        \"\"\"\n        pass\n\n    def __getitem__(self, index: int) -> RLElement:\n        return self.history[index]\n\n    def __len__(self) -> int:\n        return len(self.history)\n\n    @abstractmethod\n    def create_loader(\n            self,\n            batch_size: int,\n            shuffle: bool,\n            prep_fn: Callable = None,\n            num_workers: int = 0,\n    ) -> DataLoader:\n        \"\"\"\n        Create a dataloader for the rollout store\n\n        :param prep_fn: Applied to RLElement after collation (typically tokenizer)\n        :type prep_fn: Callable\n        \"\"\"\n        pass\n\n\n@register_datapipeline\nclass PanguPipeline(BasePipeline):\n    def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: PreTrainedTokenizer):\n\n        super().__init__()\n\n        self.prompts = prompts\n        self.tokenizer = tokenizer\n        self.config = config\n        self.max_prompt_length = config.train.seq_length - config.method.gen_kwargs[\"max_new_tokens\"]\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, idx):\n        data = self.prompts[idx]\n        prompt = data['prompt']\n        prefix = data['prefix']\n        encoded_dict = self.tokenizer(prompt, self.tokenizer.sep_token + prefix,\n                                      max_length=self.max_prompt_length,\n                                      return_tensors=\"pt\",\n                                      truncation=\"only_first\",\n                                      # padding=\"max_length\",\n                                      add_special_tokens=False,\n                                      return_token_type_ids=False)\n\n        return {\n            \"input_ids\": encoded_dict['input_ids'][0],\n            \"attention_mask\": encoded_dict['attention_mask'][0],\n        }\n\n    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:\n        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)\n\n\n@register_datapipeline\nclass GLMPipeline(BasePipeline):\n    def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: PreTrainedTokenizer):\n\n        super().__init__()\n\n        self.prompts = prompts\n        self.tokenizer = tokenizer\n        # self.config = 
config\n        self.max_generation_length = config.method.gen_kwargs[\"max_new_tokens\"]\n        self.max_prompt_length = config.train.seq_length - self.max_generation_length\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, idx):\n        data = self.prompts[idx]\n        prompt = data['prompt']\n        prefix = data['prefix']\n\n        inputs = self.tokenizer(prompt, prefix + self.tokenizer.mask_token,\n                                max_length=self.max_prompt_length,\n                                truncation=\"only_first\",\n                                # padding=\"max_length\",\n                                return_tensors=\"pt\",\n                                return_token_type_ids=False)\n        inputs_glm = self.tokenizer.build_inputs_for_generation(inputs, max_gen_length=self.max_generation_length,\n                                                                padding=True)\n        return {\n            \"input_ids\": inputs_glm['input_ids'][0],\n            \"position_ids\": inputs_glm['position_ids'][0],\n            \"generation_attention_mask\": inputs_glm['generation_attention_mask'][0]\n        }\n\n    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:\n        # collate_fn = GLMDataCollator(self.tokenizer)\n        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)#, collate_fn=collate_fn)\n\n\n@register_datapipeline\nclass ChatGLMPipeline(BasePipeline):\n    def __init__(self, prompts: List[dict], config: TRLConfig, tokenizer: PreTrainedTokenizer):\n\n        super().__init__()\n\n        self.prompts = prompts\n        self.tokenizer = tokenizer\n        self.config = config\n        self.max_prompt_length = config.train.seq_length - config.method.gen_kwargs[\"max_new_tokens\"]\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, idx):\n        data = self.prompts[idx]\n        prompt = data['prompt']\n        encoded_dict = self.tokenizer(prompt, max_length=self.max_prompt_length, return_tensors=\"pt\",\n                                      truncation=\"only_first\", padding=\"max_length\")\n\n        return {\n            \"input_ids\": encoded_dict['input_ids'][0],\n            # \"attention_mask\": encoded_dict['attention_mask'][0],\n        }\n\n    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:\n        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)\n\n\nclass PPORolloutStorage(BaseRolloutStore):\n    \"\"\"\n    Rollout storage for training PPO\n    \"\"\"\n\n    def __init__(self, pad_token_id):\n        super().__init__()\n\n        self.pad_token_id = pad_token_id\n        self.history: Iterable[PPORLElement] = [None]\n\n    def push(self, exps: Iterable[PPORLElement]):\n        self.history += exps\n\n    def clear_history(self):\n        self.history = []\n\n    def export_history(self, location: str):\n        assert os.path.exists(location)\n\n        fpath = os.path.join(location, f\"epoch-{str(time.time())}.json\")\n\n        def exp_to_dict(exp):\n            return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}\n\n        data = [exp_to_dict(exp) for exp in self.history]\n        with open(fpath, \"w\") as f:\n            f.write(json.dumps(data, indent=2))\n\n    def __getitem__(self, index: int) -> PPORLElement:\n        return self.history[index]\n\n    def __len__(self) -> int:\n        return len(self.history)\n\n    def create_loader(\n            self,\n            batch_size: 
int,\n            shuffle: bool,\n    ) -> DataLoader:\n        def collate_fn(elems: Iterable[PPORLElement]):\n            return PPORLBatch(\n                torch.stack([elem.query_tensor for elem in elems]),\n                # # Left padding of already left-padded queries\n                # pad_sequence(\n                #     [elem.query_tensor.flip(0) for elem in elems],\n                #     padding_value=self.pad_token_id,\n                #     batch_first=True,\n                # ).flip(1),\n                # Right pad the rest, to have a single horizontal query/response split\n                pad_sequence(\n                    [elem.response_tensor for elem in elems],\n                    padding_value=self.pad_token_id,\n                    batch_first=True,\n                ),\n                torch.stack([elem.attention_mask for elem in elems]),\n                [elem.position_ids for elem in elems],\n                pad_sequence(\n                    [elem.logprobs for elem in elems],\n                    padding_value=0.0,\n                    batch_first=True,\n                ),\n                pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True),\n                pad_sequence(\n                    [elem.rewards for elem in elems],\n                    padding_value=0.0,\n                    batch_first=True,\n                ),\n            )\n\n        return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)"
  },
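  {
    "path": "examples/pipeline_registry_demo.py",
    "content": "\n# Illustrative sketch -- NOT part of the original repo. It registers a toy\n# pipeline through the @register_datapipeline decorator from\n# src/data/pipeline.py and looks it up in the _DATAPIPELINE registry\n# (class names are lower-cased when registered without an explicit name).\nimport torch\n\nfrom torch.utils.data import DataLoader\n\nfrom src.data.pipeline import _DATAPIPELINE, BasePipeline, register_datapipeline\n\n\n@register_datapipeline\nclass ToyPipeline(BasePipeline):\n    # minimal pipeline over a list of pre-tokenized prompts\n    def __init__(self, prompts):\n        super().__init__()\n        self.prompts = prompts\n\n    def __len__(self):\n        return len(self.prompts)\n\n    def __getitem__(self, idx):\n        return {\"input_ids\": torch.tensor(self.prompts[idx], dtype=torch.long)}\n\n    def create_loader(self, batch_size, shuffle=False):\n        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)\n\n\nif __name__ == \"__main__\":\n    assert _DATAPIPELINE[\"toypipeline\"] is ToyPipeline\n    loader = ToyPipeline([[1, 2, 3], [4, 5, 6]]).create_loader(batch_size=2)\n    print(next(iter(loader))[\"input_ids\"].shape)  # torch.Size([2, 3])\n"
  },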
  {
    "path": "src/data_prepare.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport time\nimport json\nimport argparse\nimport random\nimport glob\nimport pandas as pd\nimport opencc\n\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\n\nfrom src.utils.logger import logger\nfrom src.utils.file_utils import set_seed\nfrom src.utils.nlp_utils import clean_text\n\n\ndef weibo_summary_comment(args, tokenizer):\n    ''' Weibo news + abstract + comments. Each comment carries a like count, so a reward training set can be built from the like counts.\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    t = time.time()\n    fi = os.path.join(args.data_dir, \"weibo_summary_comments_json.json\")\n    fo = os.path.join(args.output_dir, \"weibo_summary_comments.jsonl\")\n    data = []\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        with open(fi, \"r\", encoding=\"utf-8\") as r:\n            while True:\n                line = r.readline()\n                if not line:\n                    break\n\n                item = json.loads(line.strip(\"\\n\"))\n                article = item['article'].replace(\" \", \"\")\n                abstract = item['abstract'].replace(\" \", \"\")\n                prompt = f\"新闻内容：{article} 摘要：{abstract}\"\n                prefix = \"评论：\"\n                answers = [\n                    {\n                        \"answer\": k.replace(\" \", \"\"),\n                        \"score\": int(v)\n                    } for (k, v) in sorted(item['comments'], key=lambda x: (int(x[1]), len(x[0])), reverse=True)\n                ]\n                w.write(json.dumps({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix}, ensure_ascii=False)+'\\n')\n                data.append({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix})\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t} s\")\n\n    return data\n\n\ndef couplets(args, tokenizer):\n    ''' Couplet data (first and second lines). A reward training set can be built from the true second line plus negative second lines.\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    t1 = time.time()\n    fi = os.path.join(args.data_dir, \"couplets.txt\")\n    fo = os.path.join(args.output_dir, \"couplets.jsonl\")\n    l2 = []\n    nexts = dict()\n    with open(fi, \"r\", encoding=\"utf-8\") as r:\n        while True:\n            line = r.readline()\n            if not line:\n                break\n            line = line.strip(\"\\n\")\n            idx = len(line) // 2\n            prompt = line[:idx]\n            answer = line[idx+1:]\n            prefix = \"下联：\"\n            answers = [{\"answer\": answer, \"score\": 1}]\n            l2.append({\"prompt\": f\"上联：{prompt}\", \"answers\": answers, \"prefix\": prefix})\n            length = len(answer)\n            if length not in nexts:\n                nexts[length] = list()\n            nexts[length].append(answer)\n    t2 = time.time()\n    logger.info(f\"length: {len(l2)}, # different lengths: {len(nexts)}, time taken: {t2-t1} s\")\n    data = []\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for i, l in tqdm(enumerate(l2), desc=\"Processing Couplets\"):\n            answers = l['answers']\n            answer = answers[0]\n            length = len(answer['answer'])\n            # negative samples: other second lines of the same length\n            nexts_tmp = set(nexts[length]).difference({answer['answer']})\n            answers.extend([{\"answer\": fa, \"score\": 0} for fa in random.sample(list(nexts_tmp), 2)])\n            # negative samples: second lines of a different length\n            keys = set(nexts.keys())\n            keys.remove(length)\n            answers.extend([{\"answer\": random.choice(nexts[key]), \"score\": -1} for key in random.sample(list(keys), 2)])\n            #         answers = sorted(answers, key=lambda x: x['score'], reverse=True)\n            w.write(json.dumps({\"prompt\": l['prompt'], \"answers\": answers, \"prefix\": l['prefix']}, ensure_ascii=False)+'\\n')\n            data.append({\"prompt\": l['prompt'], \"answers\": answers, \"prefix\": l['prefix']})\n    #         if i % 1000 == 0:\n    #             logger.info(f\"{i} samples processed, time taken: {time.time()-t2} s\")\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t2} s\")\n\n    return data\n\n\ndef zhidao(args, tokenizer):\n    ''' Baidu Zhidao QA data. Each question has multiple answers plus a best answer, so a reward training set can be built directly.\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    t = time.time()\n    fp = os.path.join(args.data_dir, \"zhidao\", \"*.csv\")\n    fo = os.path.join(args.output_dir, \"zhidao.jsonl\")\n    data = []\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for fi in glob.glob(fp):\n            df = pd.read_csv(fi).sort_values(by=[\"title\", \"is_best\"], ascending=False)\n            prev_title = None\n            prev_prompt = None\n            prefix = \"答：\"\n            for _, val in df.iterrows():\n                if isinstance(val['question'], str) and val['question'] != val['title']:\n                    prompt = f\"问题：{val['title']} 内容：{val['question']}\"\n                else:\n                    prompt = f\"问题：{val['title']}\"\n                if prev_title is not None and prev_title == val['title']:\n                    answers.append({\"answer\": val['reply'], \"score\": val['is_best']})\n                else:\n                    if prev_title is not None:\n                        w.write(json.dumps({\"prompt\": prev_prompt, \"answers\": answers, \"prefix\": prefix}, ensure_ascii=False)+'\\n')\n                        data.append({\"prompt\": prev_prompt, \"answers\": answers, \"prefix\": prefix})\n                    answers = [{\"answer\": val['reply'], \"score\": val['is_best']}]\n                prev_prompt = prompt\n                prev_title = val['title']\n            w.write(json.dumps({\"prompt\": prev_prompt, \"answers\": answers, \"prefix\": prefix}, ensure_ascii=False)+'\\n')\n            data.append({\"prompt\": prev_prompt, \"answers\": answers, \"prefix\": prefix})\n            logger.info(f\"finished processing {os.path.basename(fi)}\")\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t} s\")\n\n    return data\n\n\ndef chinese_classical(args, tokenizer):\n    ''' Parallel classical-modern Chinese translations. Each classical passage has a modern counterpart, so a reward training set can be built from the correct translation plus negative samples.\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    t1 = time.time()\n    fp = os.path.join(args.data_dir, \"Classical-Modern\", \"bitext\", \"*\")\n    fo = os.path.join(args.output_dir, \"chinese_classical.jsonl\")\n    l3 = []\n    dicts = dict()\n    for fi in glob.glob(fp):\n        name = os.path.basename(fi)\n        dicts[name] = {\"古文\": [], \"现代文\": []}\n        p1 = None\n        p2 = None\n        with open(fi, \"r\", encoding=\"utf-8\") as r:\n            for i, line in enumerate(r):\n                line = line.strip(\"\\n\")\n                if line.startswith(\"古文\"):\n                    p1 = line[3:]\n                    dicts[name]['古文'].append(p1)\n                elif line.startswith(\"现代文\"):\n                    p2 = line[4:]\n                    dicts[name]['现代文'].append(p2)\n                elif p1 is not None and p2 is not None:\n                    pair = [(\"古文\", p1), (\"现代文\", p2)]\n                    random.shuffle(pair)\n                    prompt = f\"{pair[0][0]}：{pair[0][1]}\"\n                    prefix = f\"{pair[1][0]}：\"\n                    answers = [{\"answer\": pair[1][1], \"score\": 1}]\n                    l3.append({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix, \"name\": name})\n                    p1 = None\n                    p2 = None\n    t2 = time.time()\n    logger.info(f\"length: {len(l3)}, # different names: {len(dicts)}, time taken: {t2-t1} s\")\n    data = []\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for i, l in tqdm(enumerate(l3), desc=\"Processing Chinese Classical-Modern\"):\n            name = l['name']\n            prompt = l['prompt']\n            prefix = l['prefix']\n            answers = l['answers']\n            answer = answers[0]['answer']\n            if prompt.startswith(\"古文\"):\n                answer_type = '现代文'\n            else:\n                answer_type = '古文'\n            samples_tmp = set(dicts[name][answer_type])\n            samples_tmp.remove(answer)\n            answers.extend([{\"answer\": fa, \"score\": 0} for fa in random.sample(list(samples_tmp), 2)])\n            keys = set(dicts.keys())\n            keys.remove(name)\n            answers.extend([{\"answer\": random.choice(dicts[key][answer_type]), \"score\": -1} for key in random.sample(list(keys), 2)])\n            w.write(json.dumps({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix}, ensure_ascii=False)+'\\n')\n            data.append({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix})\n    #         if i % 100 == 0:\n    #             logger.info(f\"{i} samples processed, time taken: {time.time()-t2} s\")\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t2} s\")\n\n    return data\n\n\ndef chinese_poetry(args, tokenizer):\n    ''' Classical texts such as the Four Books and Five Classics, shi, ci and qu. Each piece has an author, title and body, so a reward training set can be built from correct vs. negative genre/author pairings.\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    converter = opencc.OpenCC('t2s.json')\n    t1 = time.time()\n    fp = [\n        # Four Books and Five Classics\n        os.path.join(args.data_dir, \"chinese-poetry\", \"lunyu\", \"lunyu.json\"),\n        os.path.join(args.data_dir, \"chinese-poetry\", \"sishuwujing\", \"*.json\"),\n        # ancient-style poetry\n        os.path.join(args.data_dir, \"chinese-poetry\", \"caocaoshiji\", \"caocao.json\"),\n        os.path.join(args.data_dir, \"chinese-poetry\", \"shijing\", \"shijing.json\"),\n        # Chuci (Songs of Chu)\n        os.path.join(args.data_dir, \"chinese-poetry\", \"chuci\", \"chuci.json\"),\n        # shi (poems)\n        os.path.join(args.data_dir, \"chinese-poetry\", \"shi\", \"poet*.json\"),\n        # ci (lyric poems)\n        os.path.join(args.data_dir, \"chinese-poetry\", \"ci\", \"ci*.json\"),\n        os.path.join(args.data_dir, \"chinese-poetry\", \"nalanxingde\", \"*.json\"),\n        os.path.join(args.data_dir, \"chinese-poetry\", \"wudai\", \"huajianji\", \"*juan.json\"),\n        os.path.join(args.data_dir, \"chinese-poetry\", \"wudai\", \"nantang\", \"poetrys.json\"),\n        # qu (Yuan dynasty arias)
        os.path.join(args.data_dir, \"chinese-poetry\", \"yuanqu\", \"yuanqu.json\"),\n    ]\n    fs = [each for f in fp for each in glob.glob(f)]\n    \n    l5 = []\n    dicts = dict()\n    for fi in fs:\n        lines = json.load(open(fi, \"r\", encoding=\"utf-8\"))\n        if isinstance(lines, dict):\n            lines = [lines]\n        for i, line in enumerate(lines):\n            if \"lunyu\" in fi:\n                author = \"孔子\"\n                genre = \"经书\"\n                title = line['chapter']\n                contents = \"\".join(line['paragraphs'])\n            elif \"daxue\" in fi:\n                author = \"曾子\"\n                genre = \"经书\"\n                title = \"大学\"\n                contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n            elif \"mengzi\" in fi:\n                author = \"孟子\"\n                genre = \"经书\"\n                title = converter.convert(line['chapter'])\n                contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n            elif \"zhongyong\" in fi:\n                author = \"孔伋\"\n                genre = \"经书\"\n                title = \"中庸\"\n                contents = converter.convert(\"\".join(line['paragraphs'])).replace(\"「\", \"“\").replace(\"」\", \"”\")\n            elif \"caocao\" in fi:\n                author = \"曹操\"\n                genre = \"古体诗\"\n                title = line['title']\n                contents = \"\".join(line['paragraphs'])\n            elif \"shijing\" in fi:\n                author = \"诗经\"\n                genre = \"古体诗\"\n                title = line['chapter'] + \"-\" + line['section'] + \"-\" + line['title']\n                contents = \"\".join(line['content'])\n            elif \"chuci\" in fi:\n                author = line['author']\n                genre = \"楚辞\"\n                title = line['section'] + \"-\" + line['title']\n                contents = \"\".join(line['content'])\n            elif \"nalanxingde\" in fi:\n                author = line['author']\n                genre = \"词\"\n                title = line['title']\n                contents = \"\".join(line['para'])\n            elif \"huajianji\" in fi:  # match the \"huajianji\" directory globbed above (was \"huajianci\", which never matched)\n                author = line['author']\n                genre = \"词\"\n                title = line['title']\n                contents = \"\".join(line['paragraphs'])\n            elif \"nantang\" in fi:\n                author = line['author']\n                genre = \"词\"\n                title = line['title']\n                contents = \"\".join(line['paragraphs'])\n            elif \"yuanqu\" in fi:\n                author = line['author']\n                genre = \"曲\"\n                title = line['title']\n                contents = \"\".join(line['paragraphs'])\n            elif \"shi\" in fi:\n                if len(line['paragraphs']) <= 0:\n                    continue\n                author = converter.convert(line['author'])\n                # heuristic: a 12-character first line = two 5-character hemistiches plus two punctuation marks\n                genre = \"五言诗\" if len(line['paragraphs'][0]) == 12 else \"七言诗\"\n                title = converter.convert(line['title'])\n                contents = converter.convert(\"\".join(line['paragraphs']))\n            elif \"ci\" in fi:\n                author = line['author']\n                genre = \"词\"\n                title = line['rhythmic']\n                contents = \"\".join(line['paragraphs'])\n            if genre not in dicts:\n                dicts[genre] = dict()\n
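            # dicts nests genre -> author -> title -> text; it is used below to draw\n            # same-author, same-genre and cross-genre negatives\n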
            if author not in dicts[genre]:\n                dicts[genre][author] = dict()\n            quantifier = \"篇\" if genre in [\"经书\", \"楚辞\"] else \"首\"\n            prompt = f\"以{author}的风格，写一{quantifier}{genre}，题为{title}\"\n            answers = [{\"answer\": contents, \"score\": 1}]\n            l5.append({\"prompt\": prompt, \"answers\": answers, \"genre\": genre, \"title\": title, \"author\": author})\n            dicts[genre][author][title] = contents\n    t2 = time.time()\n    logger.info(f\"length: {len(l5)}, # distinct genres: {len(dicts)}, time taken: {t2-t1} s\")\n    data = []\n    fo = os.path.join(args.output_dir, \"chinese_poetry.jsonl\")\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for l in tqdm(l5, desc=\"Processing Chinese Poetry\"):\n            genre = l['genre']\n            author = l['author']\n            title = l['title']\n            prompt = l['prompt']\n            answers = l['answers']\n            # negative: another work by the same author\n            titles_tmp = set(dicts[genre][author].keys())\n            titles_tmp.remove(title)\n            if len(titles_tmp) > 0:\n                t = random.choice(list(titles_tmp))\n                answers.append({\"answer\": dicts[genre][author][t], \"score\": 0})\n            # negative: a work in the same genre by a different author\n            authors_tmp = set(dicts[genre].keys())\n            authors_tmp.remove(author)\n            a = random.choice(list(authors_tmp))\n            t = random.choice(list(dicts[genre][a].keys()))\n            answers.append({\"answer\": dicts[genre][a][t], \"score\": -1})\n            # negative: a work from a different genre\n            genres_tmp = set(dicts.keys())\n            genres_tmp.remove(genre)\n            g = random.choice(list(genres_tmp))\n            a = random.choice(list(dicts[g].keys()))\n            t = random.choice(list(dicts[g][a].keys()))\n            answers.append({\"answer\": dicts[g][a][t], \"score\": -2})\n            w.write(json.dumps({\"prompt\": prompt, \"answers\": answers, \"prefix\": \"\"}, ensure_ascii=False)+'\\n')\n            data.append({\"prompt\": prompt, \"answers\": answers, \"prefix\": \"\"})\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t2} s\")\n\n    return data\n    \n\ndef baike_qa_2019(args, tokenizer):\n    ''' Baike (encyclopedia) Q&A dataset. Each question has exactly one answer, so it can be used to build an SFT training set\n\n    :param args:\n    :param tokenizer:\n\n    :return: processed json list\n    '''\n    fs = glob.glob(os.path.join(args.data_dir, \"baike_qa2019\", \"baike_qa_*.json\"))\n    fo = os.path.join(args.output_dir, \"baike_qa.jsonl\")\n    data = []\n    t = time.time()\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for f in fs:\n            with open(f, \"r\", encoding=\"utf-8\") as r:\n                while True:\n                    line = r.readline()\n                    if not line:\n                        break\n                    item = json.loads(line.strip(\"\\n\"))\n                    question = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])\n                    prompt = question\n                    prefix = \"答：\"\n                    answer = clean_text(item['answer'])\n                    answers = [{\"answer\": answer, \"score\": 1}]\n                    w.write(json.dumps({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix}, ensure_ascii=False)+'\\n')\n                    data.append({\"prompt\": prompt, \"answers\": answers, \"prefix\": prefix})\n    logger.info(f\"length: {len(data)}, time taken: {time.time()-t} s\")\n\n    return data\n\n\ndef get_parser():\n
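    \"\"\"Build the argparse CLI for corpus construction. Hypothetical example invocation (paths are placeholders):\n\n        python src/data_prepare.py --data_dir <raw_data> --output_dir <processed> --model_name_or_path <tokenizer_dir>\n    \"\"\"\n    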
parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = get_parser()\n    logger.info(f\"Parameters: {args}\")\n    # set random seed\n    set_seed(args.seed)\n    # load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_cache=False, trust_remote_code=True)\n\n    # process raw datasets\n    data_weibo = weibo_summary_comment(args, tokenizer)\n    data_couplets = couplets(args, tokenizer)\n    data_zhidao = zhidao(args, tokenizer)\n    data_chinese_classical = chinese_classical(args, tokenizer)\n    data_chinese_poetry = chinese_poetry(args, tokenizer)\n    data_baike = baike_qa_2019(args, tokenizer)\n\n    # merge processed datasets\n    data = data_weibo + data_couplets + data_zhidao + data_chinese_classical + data_chinese_poetry + data_baike\n    random.shuffle(data)\n    fo = os.path.join(args.output_dir, \"train_data_external_v1.jsonl\")\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for d in data:\n            w.write(json.dumps(d, ensure_ascii=False)+'\\n')\n\n    fo = os.path.join(args.output_dir, \"dev_data_external_v1.jsonl\")\n    with open(fo, \"w\", encoding=\"utf-8\") as w:\n        for d in data[:10000]:\n            w.write(json.dumps(d, ensure_ascii=False)+'\\n')\n\n    logger.info(\"Finished saving processed train & dev files\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/eval_pretrain.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport json\nimport numpy as np\nimport torch\nimport collections\n\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader, SequentialSampler\nfrom torchmetrics.text.perplexity import Perplexity\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom transformers.generation.utils import LogitsProcessorList\n\nfrom src.data.data import (\n    OCNLIDataset,\n    CMNLIDataset,\n    CHIDDataset,\n    CMRCDataset,\n    CLUEWSCDataset,\n    C3Dataset,\n    AFQMCDataset,\n    CSLDataset,\n    IFLYTEKDataset,\n    TNEWSDataset,\n    CEvalDataset,\n    MMLUDataset,\n)\nfrom src.utils import RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.utils.file_utils import set_seed, print_rank_0\n\n\nDATASET = {\n    \"ceval\": CEvalDataset,\n    \"mmlu\": MMLUDataset,\n    # NLI\n    \"ocnli\": OCNLIDataset,\n    \"cmnli\": CMNLIDataset,\n    # Cloze and completion\n    \"chid\": CHIDDataset,\n    # MRC\n    \"cmrc2018\": CMRCDataset,\n    # Winograd\n    \"cluewsc2020\": CLUEWSCDataset,\n    # common sense reasoning\n    \"c3\": C3Dataset,\n    # Text Classification\n    \"tnews\": TNEWSDataset,\n    \"iflytek\": IFLYTEKDataset,\n    \"afqmc\": AFQMCDataset,\n    \"csl\": CSLDataset\n}\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n    parser.add_argument(\"--task\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    # parser.add_argument(\"--local_rank\", type=int, default=0)\n    # parser.add_argument(\"--multi_card\", action=\"store_true\")\n    parser.add_argument(\"--bits\", type=int, default=16)\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...] 
numbers mean single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--max_length\", type=int, default=2048)\n    parser.add_argument(\"--max_length_generation\", type=int, default=1, help=\"Maximum number of newly generated tokens\")\n    parser.add_argument(\"--checkpoint\", type=str)\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n\n    # eval\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--submission_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--max_few_shot\", type=int, default=15, help=\"Maximum number of examples in few-shot evaulation\")\n    parser.add_argument(\"--cot\", action=\"store_true\", help=\"Whether to use Chain of Thought in evaluation\")\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=10)\n    parser.add_argument(\"--top_p\", type=float, default=0.8)\n    parser.add_argument(\"--temperature\", type=float, default=0.8)\n    parser.add_argument(\"--repetition_penalty\", type=float, default=1.0)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef extract_cot_answer(line, response):\n    #TODO: to be implemented\n    pass\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load model and tokenizer\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args)\n\n    if args.checkpoint is not None:\n        suffix = args.checkpoint.split(os.sep)[-2] + \"_\"\n        load_checkpoint(args, model)\n    else:\n        suffix = \"\"\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    dataset = DATASET.get(args.task, None)\n    if dataset is None:\n        raise ValueError(f\"Unsupported task: {args.task}\")\n    train_filename = os.path.join(args.data_dir, args.train_filename) if args.train_filename is not None else None\n    dev_dataset = dataset(args, os.path.join(args.data_dir, args.eval_filename),\n                          tokenizer, train_filename)\n\n    # Set up the metric\n    perplexity = Perplexity(ignore_index=tokenizer.pad_token_id)\n\n    def preprocess_logits_for_metrics(logits, labels):\n        labels = labels.detach().cpu()\n        probs = torch.softmax(logits, dim=-1).detach().cpu().to(torch.float32)\n        ppls = []\n        for i in range(probs.shape[0]):\n            ppl = perplexity(probs[i:i+1], labels[i:i+1])\n            ppls.append(ppl)\n\n        return torch.stack(ppls)\n\n    def calculate_f1(pred_text, label_text):\n        pred_tokens = tokenizer(pred_text, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")['input_ids'][0].tolist()\n        label_tokens = tokenizer(label_text, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False, return_tensors=\"pt\")['input_ids'][0].tolist()\n        common = collections.Counter(pred_tokens) & collections.Counter(label_tokens)\n        num_same = 
sum(common.values())\n        if len(pred_tokens) == 0 or len(label_tokens) == 0:\n            return int(pred_tokens == label_tokens)\n        if num_same == 0:\n            return 0\n        precision = num_same / len(pred_tokens)\n        recall = num_same / len(label_tokens)\n        f1 = (2 * precision * recall) / (precision + recall)\n\n        return f1\n\n    device = f\"cuda:{torch.cuda.current_device()}\" if torch.cuda.is_available() else \"cpu\"\n    model.eval()\n\n    if args.train_filename is None:\n        output_filename = os.path.join(args.output_dir, f\"{args.task}_{args.eval_filename}_zero-shot_{args.max_length}_{suffix}eval_result.jsonl\")\n    else:\n        assert args.max_few_shot > 0\n        output_filename = os.path.join(args.output_dir, f\"{args.task}_{args.eval_filename}_{args.max_few_shot}-shot_{args.max_length}_{suffix}eval_result.jsonl\")\n\n    if args.task in [\"cmrc2018\"]:\n        # text_generator = TextGenerationPipeline(model, tokenizer, device=device)\n        ems = []\n        f1s = []\n        with open(output_filename, \"w\", encoding=\"utf-8\") as w:\n            with torch.no_grad():\n                for dev_data in tqdm(dev_dataset.post_list, desc=\"Generation\"):\n                    prompt = dev_data['prompt']\n                    label = dev_data['label']\n                    if \"glm\" in args.model_name_or_path.lower():\n                        prompt += tokenizer.mask_token\n                        inputs = tokenizer(prompt, return_tensors=\"pt\")\n                        inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=args.max_length + args.max_length_generation)\n                        inputs = inputs.to(device)\n                        outputs = model.generate(**inputs,\n                                                 max_new_tokens=args.max_length_generation,\n                                                 eos_token_id=eos_token_id,\n                                                 pad_token_id=tokenizer.pad_token_id,\n                                                 do_sample=False,\n                                                 num_return_sequences=args.num_return_sequences,\n                                                 top_p=args.top_p,\n                                                 temperature=args.temperature)\n                    else:\n                        inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n                        inputs = inputs.to(device)\n                        outputs = model.generate(**inputs,\n                                                 max_new_tokens=args.max_length_generation,\n                                                 pad_token_id=tokenizer.pad_token_id,\n                                                 do_sample=False,\n                                                 num_return_sequences=args.num_return_sequences,\n                                                 top_p=args.top_p,\n                                                 temperature=args.temperature)\n                        # outputs = text_generator(prompt, max_length=args.max_length_generation,\n                        #                          do_sample=True, num_return_sequences=args.num_return_sequences,\n                        #                          top_p=args.top_p, temperature=args.temperature)\n                        # results = [output['generated_text'].split(\"答:\", maxsplit=1)[1].replace(tokenizer.eos_token, 
\"\").replace(tokenizer.pad_token, \"\") for output in outputs]\n                    results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n                    results = [result.split(\"答:\", maxsplit=1)[1] for result in results]\n\n                    # metrics calculation\n                    em_max = -1\n                    f1_max = -1\n                    for l in label:\n                        for pred_text in results:\n                            label_text = l['text']\n                            em = 1 if pred_text == label_text else 0\n                            f1 = calculate_f1(pred_text, label_text)\n                            w.write(json.dumps({\"prompt\": prompt, \"label\": label_text,\n                                                \"pred\": pred_text, \"em\": em, \"f1\": f1}, ensure_ascii=False)+\"\\n\")\n                            if em > em_max:\n                                em_max = em\n                            if f1 > f1_max:\n                                f1_max = f1\n                    ems.append(em_max)\n                    f1s.append(f1_max)\n\n        print_rank_0(f\"em={np.mean(ems)}, f1={np.mean(f1s)}\")\n    elif args.task in [\"ceval\"]:\n        results = dict()\n        with torch.no_grad():\n            for dev_data in tqdm(dev_dataset, desc=\"C-Eval Evaluation\"):\n                subject_name_key = dev_data['subject_name_key']\n                if subject_name_key not in results:\n                    results[subject_name_key] = list()\n                if \"chatglm\" in args.model_name_or_path.lower():\n                    logits_processor = LogitsProcessorList()\n                    if \"chatglm2\" in args.model_name_or_path.lower():\n                        class InvalidScoreLogitsProcessor(LogitsProcessor):\n                            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n                                if torch.isnan(scores).any() or torch.isinf(scores).any():\n                                    scores.zero_()\n                                    scores[..., 5] = 5e4\n                                return scores\n                    else:\n                        class InvalidScoreLogitsProcessor(LogitsProcessor):\n                            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n                                if torch.isnan(scores).any() or torch.isinf(scores).any():\n                                    scores.zero_()\n                                    scores[..., 20005] = 5e4\n                                return scores\n                    logits_processor.append(InvalidScoreLogitsProcessor())\n                    input_ids = dev_data['input_ids'].to(device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             logits_processor=logits_processor,\n                                             output_scores=not args.cot,\n                                             
return_dict_in_generate=not args.cot)\n                elif \"qwen\" in args.model_name_or_path.lower():\n                    input_ids = dev_data['input_ids'].to(device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             output_scores=not args.cot,\n                                             return_dict_in_generate=not args.cot)\n                else:\n                    input_ids = dev_data['input_ids'].to(device)\n                    attention_mask = dev_data['attention_mask'].to(device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             attention_mask=attention_mask,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             output_scores=not args.cot,\n                                             return_dict_in_generate=not args.cot)\n\n                # output processing and answer extraction\n                if args.cot:\n                    # strip the prompt tokens; input_ids is a tensor of shape [1, seq_len]\n                    outputs = outputs['sequences'].tolist()[0][len(input_ids[0]):]\n                    response = tokenizer.decode(outputs)\n                    # response, _ = model.chat(tokenizer, dev_data['question'], history=dev_data['history'],\n                    #                          do_sample=False, )\n                    response = response.strip()\n                    # ans, direct_extract = extract_cot_answer(dev_data, response)\n                else:\n                    logits = outputs['scores'][0].flatten()\n                    pred = torch.tensor(\n                                [\n                                    logits[tokenizer.encode(\"A\", add_special_tokens=False)[0]],\n                                    logits[tokenizer.encode(\"B\", add_special_tokens=False)[0]],\n                                    logits[tokenizer.encode(\"C\", add_special_tokens=False)[0]],\n                                    logits[tokenizer.encode(\"D\", add_special_tokens=False)[0]],\n                                ]\n                            ).argmax().detach().cpu().tolist()\n                    pred = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}[pred]\n                    # correct = 1 if pred == label else 0\n                    results[subject_name_key].append((dev_data['id'], dev_data['answer'], pred))\n\n        # metrics calculation\n        subject_mapping = json.load(open(os.path.join(RESOURCE_PATH, \"eval\", \"ceval\", \"subject_mapping.json\")))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as w:\n            result_dict = dict()\n
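            # result_dict: subject -> {question id -> predicted option}, dumped as the submission file;\n            # acc_dict: domain -> running correct/total counts for accuracy\n            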
acc_dict = dict()\n            for subject_name_key, vals in results.items():\n                if subject_name_key not in result_dict:\n                    result_dict[subject_name_key] = dict()\n                domain = subject_mapping[subject_name_key][2]\n                if domain not in acc_dict:\n                    acc_dict[domain] = {\"ct\": 0, \"correct\": 0}\n                for id_, label, pred in vals:\n                    result_dict[subject_name_key][str(id_)] = pred\n                    acc_dict[domain]['correct'] += 1 if pred == label else 0\n                    acc_dict[domain]['ct'] += 1\n                    w.write(json.dumps({\"subject_name_key\": subject_name_key, \"id\": id_,\n                                        \"pred\": pred, \"label\": label}, ensure_ascii=False)+\"\\n\")\n\n        # if submission file is not none, then there is no label to calculate accuracy\n        if args.submission_filename is not None:\n            json.dump(result_dict, open(os.path.join(args.output_dir, args.submission_filename), \"w\", encoding=\"utf-8\"),\n                      ensure_ascii=False)\n            print_rank_0(f\"Finished saving C-Eval Evaluation Result\")\n        else:\n            ct = 0\n            correct = 0\n            for domain, val in acc_dict.items():\n                ct += val['ct']\n                correct += val['correct']\n                print_rank_0(f\"[C-Eval Evaluation Result] domain: {domain}, acc: {val['correct'] / val['ct']}\")\n            print_rank_0(f\"[C-Eval Evaluation Result] total acc: {correct / ct}\")\n    elif args.task in [\"mmlu\"]:\n        results = dict()\n        with torch.no_grad():\n            for dev_data in tqdm(dev_dataset, desc=\"MMLU Evaluation\"):\n                subject_name_key = dev_data['subject_name_key']\n                if subject_name_key not in results:\n                    results[subject_name_key] = list()\n                if \"chatglm\" in args.model_name_or_path.lower():\n                    logits_processor = LogitsProcessorList()\n                    if \"chatglm2\" in args.model_name_or_path.lower():\n                        class InvalidScoreLogitsProcessor(LogitsProcessor):\n                            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n                                if torch.isnan(scores).any() or torch.isinf(scores).any():\n                                    scores.zero_()\n                                    scores[..., 5] = 5e4\n                                return scores\n                    else:\n                        class InvalidScoreLogitsProcessor(LogitsProcessor):\n                            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n                                if torch.isnan(scores).any() or torch.isinf(scores).any():\n                                    scores.zero_()\n                                    scores[..., 20005] = 5e4\n                                return scores\n                    logits_processor.append(InvalidScoreLogitsProcessor())\n                    input_ids = dev_data['input_ids'].to(device)\n\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                               
              top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             logits_processor=logits_processor,\n                                             output_scores=True,\n                                             return_dict_in_generate=True)\n                elif \"qwen\" in args.model_name_or_path.lower():\n                    input_ids = dev_data['input_ids'].to(device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             output_scores=True,\n                                             return_dict_in_generate=True)\n                else:\n                    input_ids = dev_data['input_ids'].to(device)\n                    attention_mask = dev_data['attention_mask'].to(device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             attention_mask=attention_mask,\n                                             max_new_tokens=args.max_length_generation,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty,\n                                             output_scores=True,\n                                             return_dict_in_generate=True)\n\n                # output processing and answer extraction\n                logits = outputs['scores'][0].flatten()\n                pred = torch.tensor(\n                    [\n                        logits[tokenizer.encode(\"A\", add_special_tokens=False)[0]],\n                        logits[tokenizer.encode(\"B\", add_special_tokens=False)[0]],\n                        logits[tokenizer.encode(\"C\", add_special_tokens=False)[0]],\n                        logits[tokenizer.encode(\"D\", add_special_tokens=False)[0]],\n                    ]\n                ).argmax().detach().cpu().tolist()\n                pred = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}[pred]\n                # correct = 1 if pred == label else 0\n                results[subject_name_key].append((dev_data['answer'], pred))\n\n        # metrics calculation\n        subject_mapping = json.load(open(os.path.join(RESOURCE_PATH, \"eval\", \"mmlu\", \"subject_mapping.json\")))\n        with open(output_filename, \"w\", encoding=\"utf-8\") as w:\n            acc_dict = dict()\n            for subject_name_key, vals in results.items():\n                domain = subject_mapping[subject_name_key][1]\n                if domain not in acc_dict:\n                    acc_dict[domain] = {\"ct\": 0, \"correct\": 0}\n                for label, pred in vals:\n                    # 
result_dict[subject_name_key] = pred\n                    acc_dict[domain]['correct'] += 1 if pred == label else 0\n                    acc_dict[domain]['ct'] += 1\n                    w.write(json.dumps({\"subject_name_key\": subject_name_key,\n                                        \"pred\": pred, \"label\": label}, ensure_ascii=False)+\"\\n\")\n        ct = 0\n        correct = 0\n        for domain, val in acc_dict.items():\n            ct += val['ct']\n            correct += val['correct']\n            print_rank_0(f\"[MMLU Evaluation Result] domain: {domain}, acc: {val['correct'] / val['ct']}\")\n        print_rank_0(f\"[MMLU Evaluation Result] total acc: {correct / ct}\")\n    else:\n        sampler = SequentialSampler(dev_dataset)\n        dev_dataloader = DataLoader(dev_dataset, sampler=sampler, batch_size=args.eval_batch_size)\n\n        ppl_list = []\n        input_ids_list = []\n        label_list = []\n        ls_list = []\n\n        with torch.no_grad():\n            for batch in tqdm(dev_dataloader, desc=\"Evaluation\"):\n                input_ids = batch['input_ids'].squeeze(1).to(device)\n                attention_mask = batch['attention_mask'].squeeze(1).to(device)\n                labels = batch['labels'].squeeze(1).to(device)\n                out = model(input_ids, attention_mask=attention_mask)\n                ppls = preprocess_logits_for_metrics(out.logits, labels)\n                input_ids_list.extend(batch['input_ids'].detach().cpu().tolist())\n                ppl_list.extend(ppls.detach().cpu().tolist())\n                label_list.extend(batch['label_str'])\n                if args.task in ['chid', 'c3', 'iflytek', 'tnews']:\n                    ls = np.array(batch['candidates']).transpose().tolist()\n                    ls_list.extend(ls)\n                else:\n                    vals = list(dev_dataset.label_dict.values())\n                    ls_list.extend([vals]*input_ids.shape[0])\n\n        ct = 0\n        ct_acc = 0\n        ppls = []\n        with open(output_filename, \"w\", encoding=\"utf-8\") as w:\n            for i, (input_ids, label, ls, ppl) in enumerate(zip(input_ids_list, label_list, ls_list, ppl_list)):\n                ppls.append(ppl)\n                prompt = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]\n                if i % len(ls) == len(ls) - 1:\n                    lidx = ls.index(label)\n                    if np.argmin(ppls) == lidx:\n                        ct_acc += 1\n                    ct += 1\n                    # cur_label = None\n                    ppls = []\n                w.write(json.dumps({\"prompt\": prompt, \"pred\": float(ppl), \"label\": label}, ensure_ascii=False) + \"\\n\")\n\n        print_rank_0(f\"acc={ct_acc/ct}\")  # fraction of examples where the minimum-perplexity candidate is the label\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/models/__init__.py",
    "content": "from src.models.reward import RewardModel\nfrom src.models.loss import PairWiseLoss\n# from src.models.llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel"
  },
  {
    "path": "src/models/loss.py",
    "content": "\nimport torch\nimport torch.nn as nn\n\n\nclass PairWiseLoss(nn.Module):\n    \"\"\"\n    Pairwise Loss for Reward Model\n    \"\"\"\n\n    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:\n        probs = torch.sigmoid(chosen_reward - reject_reward)\n        log_probs = torch.log(probs)\n        loss = -log_probs.mean()\n        return loss"
  },
  {
    "path": "src/models/ppo.py",
    "content": "# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from\n# HuggingFace's `trl` library. The original source code can be found here:\n# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py\n\nimport gc\nimport inspect\nimport json\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport transformers\n\nfrom copy import deepcopy\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom dataclasses import dataclass\nfrom transformers.modeling_outputs import ModelOutput\nfrom transformers.models.bloom import modeling_bloom\nfrom transformers.models.opt import modeling_opt\nfrom huggingface_hub import hf_hub_download\nfrom peft import (\n    LoraConfig,\n    get_peft_model\n)\n\nfrom src.utils.modeling_utils import (\n    hf_get_decoder,\n    hf_get_decoder_blocks,\n    hf_get_decoder_final_norm,\n    hf_get_hidden_size,\n    hf_get_lm_head,\n    hf_get_num_hidden_layers,\n    make_head\n)\n\n\nclass PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):\n    \"\"\"A wrapper around `transformers.PreTrainedModel`\n\n    Reference: @younesbelkada's `PreTrainedModelWrapper`\n    https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2\n\n    Attributes:\n        _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel`\n            type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`.\n        _supported_modules (List[str]): A list of attribute names for modules of\n            the underlying architecture model. This is used, for example, to save\n            and load any additional modules by manipulating the state dict.\n        _supported_args (List[str]): A list of arguments specific to the underlying\n            architecture to separate from arguments that are supported by the\n            parent `AutoModel` class. 
Any arguments that are not supported by the\n            underlying model will be passed to the parent `AutoModel` class.\n    \"\"\"\n\n    _auto_model_parent_class: transformers.AutoModel = None\n    _supported_modules: List[str] = None\n    # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the\n    # specific underlying type similar to how config instances can be used to instantiate\n    # `transformers.PreTrainedModel`s.\n    _supported_args: List[str] = None\n\n    def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs):\n        super().__init__()\n        self.base_model = base_model\n        # cache `forward` args for general use (avoids incompatible args across architectures)\n        self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args\n\n    @classmethod\n    def _split_kwargs(cls, kwargs: Dict[str, Any]):\n        \"\"\"Separates the kwargs from the supported arguments within `supported_args`\n        and those that are not\n        \"\"\"\n        supported_kwargs = {}\n        unsupported_kwargs = {}\n        for key, value in kwargs.items():\n            if key in cls._supported_args:\n                supported_kwargs[key] = value\n            else:\n                unsupported_kwargs[key] = value\n        return supported_kwargs, unsupported_kwargs\n\n    @classmethod\n    def from_config(cls, config: transformers.PretrainedConfig, **kwargs):\n        \"\"\"Instantiate the pretrained pytorch model from a configuration.\n\n        Args:\n            config (transformers.PretrainedConfig): The configuration to use to\n                instantiate the base model.\n\n        NOTE: Loading a model from its configuration file does **not** load the\n        model weights. It only affects the model's configuration. Use\n        `~transformers.AutoModel.from_pretrained` to load the model weights.\n        \"\"\"\n        if kwargs is not None:\n            wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs)\n        else:\n            from_config_kwargs = {}\n            wrapped_model_kwargs = {}\n        base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs)\n        model = cls(base_model, **wrapped_model_kwargs)\n        return model\n\n    @classmethod\n    def from_pretrained(  # noqa: max-complexity\n        cls,\n        pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel],\n        *model_args,\n        **kwargs,\n    ):\n        \"\"\"Instantiate a pretrained pytorch model from a pretrained model configuration.\n        This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`.\n        Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained`\n        for more information.\n\n        Args:\n            pretrained_model_name_or_path (str or `transformers.PreTrainedModel`):\n                The identifier of the pretrained model to load or the pretrained model itself.\n            *model_args (sequence of positional arguments, *optional*):\n                All remaining positional arguments will be passed to the `_auto_model_parent_class`.\n            **kwargs (dict, *optional*):\n                Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class`\n                call (e.g. 
`transformers.AutoModelForCausalLM.from_pretrained`) and the specific\n                instance of the wrapped model.\n\n        NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.\n        \"\"\"\n        if kwargs is not None:\n            wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs)\n        else:\n            from_pretrained_kwargs = {}\n            wrapped_model_kwargs = {}\n\n        if isinstance(pretrained_model_name_or_path, str):\n            # Load the base model using the `transformers` AutoClass (e.g. AutoModelForCausalLM)\n            base_model = cls._auto_model_parent_class.from_pretrained(\n                pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs\n            )\n        elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):\n            base_model = pretrained_model_name_or_path\n        else:\n            raise ValueError(\n                f\"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}\"\n                \"Expected `str` or `transformers.PreTrainedModel`.\"\n            )\n        config = from_pretrained_kwargs.get(\"config\", None)\n        if config is not None:\n            base_model.config.lora_rank = config.train.lora_rank\n            base_model.config.lora_alpha = config.train.lora_alpha\n            base_model.config.lora_train_bias = config.train.lora_train_bias\n        model = cls(base_model, **wrapped_model_kwargs)\n\n        if isinstance(pretrained_model_name_or_path, str):\n            filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin\")\n            sharded_index_filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin.index.json\")\n            is_sharded = False\n\n            if not os.path.exists(filename):\n                try:\n                    filename = hf_hub_download(pretrained_model_name_or_path, \"pytorch_model.bin\")\n                # Sharded\n                except Exception:\n                    if os.path.exists(sharded_index_filename):\n                        index_file_name = sharded_index_filename\n                    else:\n                        index_file_name = hf_hub_download(\n                            pretrained_model_name_or_path,\n                            \"pytorch_model.bin.index.json\",\n                        )\n                    with open(index_file_name, \"r\") as f:\n                        index = json.load(f)\n                    # Collect files containing weights from supported modules\n                    files_to_download = set()\n                    for k, v in index[\"weight_map\"].items():\n                        if any([module in k for module in cls._supported_modules]):\n                            files_to_download.add(v)\n                    is_sharded = True\n\n            if is_sharded:\n                # Merge each shard into a state dict\n                # TODO: Optimize this to avoid wasting RAM\n                state_dict = {}\n                for shard_file in files_to_download:\n                    filename = os.path.join(pretrained_model_name_or_path, shard_file)\n                    # Download if shard file doesn't exist locally\n                    if not os.path.exists(filename):\n                        filename = hf_hub_download(pretrained_model_name_or_path, shard_file)\n                    state_dict.update(torch.load(filename, map_location=\"cpu\"))\n            else:\n                
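# unsharded checkpoint: a single pytorch_model.bin holds the full state dict\n                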
state_dict = torch.load(filename, map_location=\"cpu\")\n        else:\n            state_dict = pretrained_model_name_or_path.state_dict()\n\n        model.post_init(state_dict=state_dict)\n        return model\n\n    def save_pretrained(self, *args, **kwargs):\n        \"\"\"Save the pretrained model to a directory. This method is a wrapper\n        around `transformers.PreTrainedModel.save_pretrained`. Please refer to\n        the documentation of `transformers.PreTrainedModel.save_pretrained` for\n        more information.\n\n        Args:\n            *args (`list`, *optional*):\n                Positional arguments passed along to the underlying model's\n                `save_pretrained` method.\n            **kwargs (`dict`, *optional*):\n                Keyword arguments passed along to the underlying model's\n                `save_pretrained` method.\n        \"\"\"\n        state_dict = kwargs.pop(\"state_dict\", None)\n        if state_dict is None:\n            state_dict = self.state_dict()\n            kwargs[\"state_dict\"] = state_dict\n\n        return self.base_model.save_pretrained(*args, **kwargs)\n\n    def state_dict(self, *args, **kwargs):\n        \"\"\"Return the state_dict of the pretrained model.\"\"\"\n        raise NotImplementedError\n\n    def post_init(self, *args, **kwargs):\n        \"\"\"Post initialization method. This method is called after the model is\n        instantiated and loaded from a checkpoint. It can be used to perform\n        additional operations such as loading the state_dict.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:\n        \"\"\"Filter out arguments not supported by the specific instance of\n        `base_model.transformer.forward`\n        \"\"\"\n        # FIXME: This is a hack to get around the fact that the `transformers`\n        # architectures we use don't have a consistent API for `forward` parameters.\n        return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}\n\n\n# KL Controllers\nclass AdaptiveKLController:\n    \"\"\"Adaptive KL Controller as described in Ziegler et al. 
\"Fine-Tuning Language Models from Human Preferences\"\n    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2\n    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py\n    \"\"\"\n\n    def __init__(self, init_kl_coef: float, target: float, horizon: int):\n        self.value = init_kl_coef\n        self.target = target\n        self.horizon = horizon\n\n    def update(self, current: float, n_steps: int):\n        \"\"\"Returns adaptively updated KL coefficient, βₜ₊₁.\n        Arguments:\n            current: The current KL value between the newest policy and the initial policy.\n        \"\"\"\n        proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)  # ϵₜ\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult  # βₜ₊₁\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current: float, n_steps: int):\n        \"\"\"Returns updated KL coefficient, βₜ₊₁.\n        Arguments:\n            current: The current KL value between the newest policy and the initial policy.\n        \"\"\"\n        pass\n\n\n# CausalLM architectures\n@dataclass\nclass CausalLMOutputWithValue(ModelOutput):\n    loss: Optional[torch.FloatTensor] = None\n    logits: Optional[torch.FloatTensor] = None\n    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n    hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None\n    value: Optional[torch.FloatTensor] = None\n\n\nclass AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):\n    \"\"\"An `AutoModel` class wrapper for `transformers` causal models that have a\n    language modeling head and a value head\n    \"\"\"\n\n    _auto_model_parent_class = transformers.AutoModelForCausalLM\n    _supported_modules = [\"v_head\"]\n    _supported_args = []\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            **kwargs\n    ):\n        super().__init__(base_model)\n        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            position_ids: Optional[List[torch.FloatTensor]] = None,\n            head_mask: Optional[torch.Tensor] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithValue]:\n        forward_kwargs = self.get_compatible_forward_kwargs(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        forward_kwargs[\"output_hidden_states\"] = True\n        
forward_kwargs[\"return_dict\"] = True\n\n        outputs = self.base_model(**forward_kwargs)\n        value = self.v_head(outputs.hidden_states[-1]).squeeze(-1)\n\n        if not return_dict:\n            outputs = (outputs.logits,) + outputs[1:] + (value,)\n            return outputs\n\n        return CausalLMOutputWithValue(**outputs, value=value)\n\n    def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:\n        return self.base_model.generate(*args, **kwargs)\n\n    def state_dict(self, *args, **kwargs):\n        \"\"\"\n        Returns the state dictionary of the model. We add the state dictionary of the value head\n        to the state dictionary of the wrapped model by prepending the key with `v_head.`.\n        \"\"\"\n        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)\n        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)\n        for k, v in v_head_state_dict.items():\n            base_model_state_dict[f\"v_head.{k}\"] = v\n        return base_model_state_dict\n\n    def post_init(self, state_dict):\n        \"\"\"\n        Adds the state dictionary of the value head to the state dictionary of the wrapped model\n        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the\n        keys of the value head state dictionary.\n        \"\"\"\n        for k in list(state_dict.keys()):\n            if \"v_head.\" in k:\n                state_dict[k.replace(\"v_head.\", \"\")] = state_dict.pop(k)\n        self.v_head.load_state_dict(state_dict, strict=False)\n        del state_dict\n        gc.collect()  # noqa: E702\n\n\nclass AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):\n    _supported_modules = [\"v_head\", \"frozen_head\"]\n    _supported_args = [\"num_layers_unfrozen\"]\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            num_layers_unfrozen: int = -1,\n    ):\n        super().__init__(base_model)\n        self.num_layers_unfrozen = num_layers_unfrozen\n        if self.num_layers_unfrozen > 0:\n            config = self.base_model.config\n            branch_class = hf_get_branch_class(config)\n            self.frozen_head = branch_class(\n                self.base_model,\n                num_layers_unfrozen=self.num_layers_unfrozen,\n            ).eval()\n\n        if base_model.config.lora_rank > 0:\n            config = LoraConfig(\n                r=base_model.config.lora_rank,\n                lora_alpha=base_model.config.lora_alpha,\n                target_modules=config.target_modules.split(\",\"),\n                lora_dropout=0.05,\n                bias=base_model.config.lora_train_bias,\n                task_type=config.task_type\n            )\n            self.base_model.enable_input_require_grads()\n            self.base_model = get_peft_model(base_model, config)\n            # convert_to_lora_recursively(base_model, base_model.config.lora_rank, base_model.config.lora_alpha)\n            # lora.mark_only_lora_as_trainable(base_model, base_model.config.lora_train_bias)\n\n    def forward_hydra(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            position_ids: Optional[List[torch.FloatTensor]] = None,\n            head_mask: Optional[torch.Tensor] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            
use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]:\n        forward_kwargs = self.get_compatible_forward_kwargs(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        return_dict = forward_kwargs.get(\"return_dict\", True)\n        forward_kwargs[\"return_dict\"] = True\n        forward_kwargs[\"output_hidden_states\"] = True\n\n        outputs = self.forward(**forward_kwargs)\n        # Select the hidden state before the first branching layer\n        input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)]\n\n        output_shape = outputs.hidden_states[-1].size()\n        forward_kwargs.pop(\"input_ids\", None)  # Ignore `input_ids` for branch head\n        forward_kwargs.pop(\"inputs_embeds\", None)  # Ignore `inputs_embeds` for branch head\n        hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs)\n\n        if not return_dict:\n            return hydra_outputs.logits\n        return hydra_outputs\n\n    @classmethod\n    def from_pretrained(  # noqa: max-complexity\n            cls,\n            pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel],\n            *model_args,\n            **kwargs,\n    ):\n        \"\"\"Instantiate a pretrained pytorch model from a pretrained model configuration.\n        This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`.\n        Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained`\n        for more information.\n\n        Args:\n            pretrained_model_name_or_path (str or `transformers.PreTrainedModel`):\n                The identifier of the pretrained model to load or the pretrained model itself.\n            *model_args (sequence of positional arguments, *optional*):\n                All remaining positional arguments will be passed to the `_auto_model_parent_class`.\n            **kwargs (dict, *optional*):\n                Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class`\n                call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific\n                instance of the wrapped model.\n\n        NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.\n        \"\"\"\n        if kwargs is not None:\n            wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs)\n        else:\n            from_pretrained_kwargs = {}\n            wrapped_model_kwargs = {}\n\n        if isinstance(pretrained_model_name_or_path, str):\n            # Load the base model using the `transformers` AutoClass (e.g. 
AutoModelForCausalLM)\n            base_model = cls._auto_model_parent_class.from_pretrained(\n                pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs\n            )\n        elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):\n            base_model = pretrained_model_name_or_path\n        else:\n            raise ValueError(\n                f\"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}\"\n                \"Expected `str` or `transformers.PreTrainedModel`.\"\n            )\n        # TODO: add model.resize_token_embeddings(tokenizer.vocab_size)\n        # config = from_pretrained_kwargs.get(\"config\", None)\n        # if config is not None:\n        #     base_model.config.lora_rank = config.train.lora_rank\n        #     base_model.config.lora_alpha = config.train.lora_alpha\n        #     base_model.config.lora_train_bias = config.train.lora_train_bias\n\n        if isinstance(pretrained_model_name_or_path, str):\n            filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin\")\n            sharded_index_filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin.index.json\")\n            is_sharded = False\n\n            if not os.path.exists(filename):\n                try:\n                    filename = hf_hub_download(pretrained_model_name_or_path, \"pytorch_model.bin\")\n                # Sharded\n                except Exception:\n                    if os.path.exists(sharded_index_filename):\n                        index_file_name = sharded_index_filename\n                    else:\n                        index_file_name = hf_hub_download(\n                            pretrained_model_name_or_path,\n                            \"pytorch_model.bin.index.json\",\n                        )\n                    with open(index_file_name, \"r\") as f:\n                        index = json.load(f)\n                    # Collect files containing weights from supported modules\n                    files_to_download = set()\n                    for k, v in index[\"weight_map\"].items():\n                        if any([module in k for module in cls._supported_modules]):\n                            files_to_download.add(v)\n                    is_sharded = True\n\n            if is_sharded:\n                # Merge each shard into a state dict\n                # TODO: Optimize this to avoid wasting RAM\n                state_dict = {}\n                for shard_file in files_to_download:\n                    filename = os.path.join(pretrained_model_name_or_path, shard_file)\n                    # Download if shard file doesn't exist locally\n                    if not os.path.exists(filename):\n                        filename = hf_hub_download(pretrained_model_name_or_path, shard_file)\n                    state_dict.update(torch.load(filename, map_location=\"cpu\"))\n            else:\n                state_dict = torch.load(filename, map_location=\"cpu\")\n        else:\n            state_dict = pretrained_model_name_or_path.state_dict()\n\n        # Check if sft model is LoRA checkpoint, load the state dict into model\n        is_lora_checkpoint = False\n        for key in state_dict.keys():\n            if \"lora\" in key:\n                is_lora_checkpoint = True\n                break\n\n        if is_lora_checkpoint:\n            # base_model = SFTModelWithLoRA(base_model.config, base_model)\n            res = 
base_model.load_state_dict(state_dict, strict=False)\n\n        model = cls(base_model, **wrapped_model_kwargs)\n\n        model.post_init(state_dict=state_dict)\n\n        return model\n\n\nclass ModelBranch(transformers.PreTrainedModel):\n    \"\"\"Implements the frozen upper trunk of the pretrained reference model used\n    when computing the PPO KL-divergence penalty.\n    \"\"\"\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            *,\n            num_layers_unfrozen: int,\n    ):\n        \"\"\"\n        Args:\n            base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from\n            num_layers_unfrozen (int): The number of trainable layers\n        \"\"\"\n        super().__init__(base_model.config)\n\n        # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model\n        decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model))\n        self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:])\n        self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model))\n        self.lm_head = deepcopy(hf_get_lm_head(base_model))\n\n        self.hidden_size = hf_get_hidden_size(self.config)\n        self.model_parallel = False\n        self.device_map = None\n        self.last_device = None\n        self.gradient_checkpointing = False\n\n        # Freeze the entire branch\n        for parameter in self.parameters():\n            parameter.requires_grad_(False)\n\n\nclass GPTModelBranch(ModelBranch):\n    def forward(  # noqa: max-complexity\n            self,\n            hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids\n            output_shape: torch.Tensor,  # output_size given by main trunk\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            head_mask: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = False,\n    ) -> Union[Tuple, CausalLMOutputWithValue]:\n        \"\"\"Reference:\n        https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743  # noqa: E501\n        \"\"\"\n        batch_size = hidden_states.size()[0]\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        device = hidden_states.device\n\n        if past_key_values is None:\n            past_key_values = tuple([None] * len(self.decoder_blocks))\n\n        if attention_mask is not None:\n            if batch_size <= 0:\n                raise ValueError(\"batch_size has to be defined and > 0\")\n            attention_mask = attention_mask.view(batch_size, -1)\n      
      attention_mask = attention_mask[:, None, None, :]\n            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min\n\n        if self.config.add_cross_attention and encoder_hidden_states is not None:\n            (\n                encoder_batch_size,\n                encoder_sequence_length,\n                _,\n            ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n            if encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)\n        else:\n            encoder_attention_mask = None\n\n        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):\n            if self.model_parallel:\n                torch.cuda.set_device(hidden_states.device)\n                if layer_past is not None:\n                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)\n                if attention_mask is not None:\n                    attention_mask = attention_mask.to(hidden_states.device)\n                if isinstance(head_mask, torch.Tensor):\n                    head_mask = head_mask.to(hidden_states.device)\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            # Assumes we are never training the branch\n            block_params = inspect.getfullargspec(block.forward).args\n            if \"encoder_hidden_states\" in block_params:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n            else:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n                if self.config.add_cross_attention:\n                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)\n\n            if self.model_parallel:\n                for k, v in self.device_map.items():\n                    if i == v[-1] and \"cuda:\" + str(k) != self.last_device:\n                        hidden_states = 
hidden_states.to(\"cuda:\" + str(k + 1))\n\n        hidden_states = self.final_norm(hidden_states)\n\n        hidden_states = hidden_states.view(output_shape)\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if self.model_parallel:\n            torch.cuda.set_device(self.transformer.first_device)\n            hidden_states = hidden_states.to(self.lm_head.weight.device)\n\n        lm_logits = self.lm_head(hidden_states)\n\n        if not return_dict:\n            outputs = (lm_logits,) + (None,) + (None,)\n            return outputs\n\n        return CausalLMOutputWithValue(\n            logits=lm_logits,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass OPTModelBranch(ModelBranch):\n    def forward(  # noqa: max-complexity\n            self,\n            hidden_states: torch.Tensor,\n            output_shape: torch.Tensor,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            head_mask: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = False,\n    ) -> Union[Tuple, CausalLMOutputWithValue]:\n        \"\"\"Reference:\n        https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840  # noqa: E501\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n\n        if attention_mask is None:\n            attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)\n\n        input_shape = hidden_states.size()[:-1]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = modeling_opt._make_causal_mask(\n                input_shape,\n                hidden_states.dtype,\n                past_key_values_length=past_key_values_length,\n            ).to(hidden_states.device)\n\n        if attention_mask is not None:\n            expanded_attn_mask = modeling_opt._expand_mask(\n                attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]\n            ).to(hidden_states.device)\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n        attention_mask = combined_attention_mask\n\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions 
else None\n        next_decoder_cache = () if use_cache else None\n\n        for attn_mask, mask_name in zip([head_mask], [\"head_mask\"]):\n            if attn_mask is not None:\n                if attn_mask.size()[0] != (len(self.decoder_blocks)):\n                    raise ValueError(\n                        f\"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for\"\n                        f\" {head_mask.size()[0]}.\"\n                    )\n\n        for idx, decoder_layer in enumerate(self.decoder_blocks):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                past_key_value=past_key_value,\n                attention_mask=attention_mask,\n                layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        if self.final_norm is not None:\n            hidden_states = self.final_norm(hidden_states)\n\n        # TODO: Add output projection support\n        # https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499  # noqa: E501\n        # if self.project_out is not None:\n        #     hidden_states = self.project_out(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n\n        lm_logits = self.lm_head(hidden_states).contiguous()\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    lm_logits,\n                    hidden_states,\n                    next_cache,\n                    all_hidden_states,\n                    all_self_attns,\n                ]\n                if v is not None\n            )\n\n        return CausalLMOutputWithValue(\n            logits=lm_logits,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass BloomModelBranch(ModelBranch):\n    def forward(  # noqa: max-complexity\n            self,\n            hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids\n            output_shape: torch.Tensor,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            head_mask: Optional[torch.FloatTensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = False,\n    ) -> Union[Tuple, CausalLMOutputWithValue]:\n        \"\"\"Reference:\n        
https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623  # noqa: E501\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        batch_size, seq_length = hidden_states.shape[:2]\n\n        if past_key_values is None:\n            past_key_values = tuple([None] * len(self.decoder_blocks))\n\n        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values[0] is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n        if attention_mask is None:\n            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)\n        else:\n            attention_mask = attention_mask.to(hidden_states.device)\n\n        alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)\n\n        combined_attention_mask = None\n        device = attention_mask.device\n        input_shape = (batch_size, seq_length)\n        _, src_length = input_shape\n\n        if src_length > 1:\n            combined_attention_mask = modeling_bloom._make_causal_mask(\n                input_shape,\n                device=device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length)\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask\n        )\n        causal_mask = combined_attention_mask\n\n        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            outputs = block(\n                hidden_states,\n                layer_past=layer_past,\n                attention_mask=causal_mask,\n                head_mask=head_mask[i],\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                alibi=alibi,\n            )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = self.final_norm(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        lm_logits = self.lm_head(hidden_states)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    lm_logits,\n    
                hidden_states,\n                    presents,\n                    all_hidden_states,\n                    all_self_attentions,\n                ]\n                if v is not None\n            )\n\n        return CausalLMOutputWithValue(\n            logits=lm_logits,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\n# Seq2Seq architectures\n@dataclass\nclass Seq2SeqLMOutputWithValue(ModelOutput):\n    loss: Optional[torch.FloatTensor] = None\n    logits: Optional[torch.FloatTensor] = None\n    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None\n    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None\n    encoder_last_hidden_state: Optional[torch.FloatTensor] = None\n    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None\n    value: Optional[torch.FloatTensor] = None\n\n\nclass AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):\n    \"\"\"An `AutoModel` class wrapper for `transformers` sequence-to-sequence\n    models that have a language modeling head and a value head\n    \"\"\"\n\n    _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM\n    _supported_modules = [\"v_head\"]\n    _supported_args = []\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            **kwargs\n    ):\n        super().__init__(base_model)\n        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)\n\n    def forward(\n            self,\n            input_ids: Optional[torch.LongTensor] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            position_ids: Optional[torch.FloatTensor] = None,\n            decoder_input_ids: Optional[torch.LongTensor] = None,\n            decoder_attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,\n            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n            head_mask: Optional[torch.FloatTensor] = None,\n            decoder_head_mask: Optional[torch.FloatTensor] = None,\n            cross_attn_head_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = True,\n            output_hidden_states: Optional[bool] = True,\n            return_dict: Optional[bool] = None,\n    ) -> Seq2SeqLMOutputWithValue:\n        forward_kwargs = self.get_compatible_forward_kwargs(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            decoder_input_ids=decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_outputs=encoder_outputs,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            decoder_inputs_embeds=decoder_inputs_embeds,\n            head_mask=head_mask,\n            decoder_head_mask=decoder_head_mask,\n            cross_attn_head_mask=cross_attn_head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        forward_kwargs[\"output_hidden_states\"] = True\n        forward_kwargs[\"return_dict\"] = True\n\n        # print(f\"forward_kwargs: {forward_kwargs.keys()}, input_ids shape: {forward_kwargs['input_ids'].shape}, attention_mask shape: {forward_kwargs['attention_mask'].shape}\")\n        outputs = self.base_model(**forward_kwargs)\n        # print(f\"outputs: {outputs.keys()}\")\n        # GLM-style models expose hidden states via `mems`; a missing field on a ModelOutput\n        # raises AttributeError rather than KeyError, so catch both\n        try:\n            hidden_states = outputs.mems\n        except (KeyError, AttributeError):\n            hidden_states = outputs.hidden_states\n        last_hidden_state = hidden_states[-1]\n        value = self.v_head(last_hidden_state).squeeze(-1)\n\n        return Seq2SeqLMOutputWithValue(\n            loss=outputs.loss,\n            logits=outputs.logits,\n            decoder_hidden_states=hidden_states,\n            # Propagate the encoder state so that `forward_hydra` can pass it to the frozen branch\n            encoder_last_hidden_state=getattr(outputs, \"encoder_last_hidden_state\", None),\n            value=value,\n        )\n\n    def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:\n        return self.base_model.generate(*args, **kwargs)\n\n    def state_dict(self, *args, **kwargs):\n        \"\"\"\n        Returns the state dictionary of the model. We add the state dictionary of the value head\n        to the state dictionary of the wrapped model by prepending the key with `v_head.`.\n        \"\"\"\n        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)\n        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)\n        for k, v in v_head_state_dict.items():\n            base_model_state_dict[f\"v_head.{k}\"] = v\n        return base_model_state_dict\n\n    def post_init(self, state_dict):\n        \"\"\"\n        Loads the value head weights from `state_dict` into `self.v_head`. Keys are stored\n        with a `v_head.` prefix (see `state_dict()` above), so the prefix is stripped\n        before loading.\n        \"\"\"\n        for k in list(state_dict.keys()):\n            if \"v_head.\" in k:\n                state_dict[k.replace(\"v_head.\", \"\")] = state_dict.pop(k)\n        self.v_head.load_state_dict(state_dict, strict=False)\n        del state_dict\n        gc.collect()  # noqa: E702\n\n\nclass AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead):\n    _supported_modules = [\"v_head\", \"frozen_head\"]\n    _supported_args = [\"num_layers_unfrozen\"]\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            num_layers_unfrozen: int = -1,\n    ):\n        super().__init__(base_model=base_model)\n        self.num_layers_unfrozen = num_layers_unfrozen\n        if self.num_layers_unfrozen > 0:\n            branch_class = T5Branch  # TODO: Add support for other model branches\n            self.frozen_head = branch_class(\n                self.base_model,\n                num_layers_unfrozen=self.num_layers_unfrozen,\n            ).eval()\n\n        # if base_model.config.lora_rank > 0:\n        #     convert_to_lora_recursively(base_model, base_model.config.lora_rank, base_model.config.lora_alpha)\n        #     lora.mark_only_lora_as_trainable(base_model, base_model.config.lora_train_bias)\n\n    def forward_hydra(\n            self,\n            input_ids: Optional[torch.LongTensor] = None,\n            attention_mask: Optional[torch.FloatTensor] = None,\n            decoder_input_ids: Optional[torch.LongTensor] = None,\n            decoder_attention_mask: Optional[torch.FloatTensor] = None,\n            encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,\n            
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n            head_mask: Optional[torch.FloatTensor] = None,\n            decoder_head_mask: Optional[torch.FloatTensor] = None,\n            cross_attn_head_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Seq2SeqLMOutputWithValue:\n        forward_kwargs = self.get_compatible_forward_kwargs(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            decoder_input_ids=decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_outputs=encoder_outputs,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            decoder_inputs_embeds=decoder_inputs_embeds,\n            head_mask=head_mask,\n            decoder_head_mask=decoder_head_mask,\n            cross_attn_head_mask=cross_attn_head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        return_dict = forward_kwargs.get(\"return_dict\", True)\n        forward_kwargs[\"output_hidden_states\"] = True\n        forward_kwargs[\"return_dict\"] = True\n\n        outputs = self.forward(**forward_kwargs)\n        # Select the hidden state before the first branching layer\n        input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)]\n        hydra_outputs = self.frozen_head(\n            hidden_states=input_hidden_state,\n            attention_mask=decoder_attention_mask,\n            encoder_hidden_states=outputs.encoder_last_hidden_state,\n            encoder_attention_mask=attention_mask,\n            use_cache=False,\n            output_attentions=False,\n            output_hidden_states=True,\n            # Always request a dataclass from the branch so `.logits` below is well-defined\n            return_dict=True,\n        )\n\n        if not return_dict:\n            return hydra_outputs.logits\n        return hydra_outputs\n\n    @classmethod\n    def from_pretrained(  # noqa: max-complexity\n            cls,\n            pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel],\n            *model_args,\n            **kwargs,\n    ):\n        \"\"\"Instantiate a pretrained pytorch model from a pretrained model configuration.\n        This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`.\n        Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained`\n        for more information.\n\n        Args:\n            pretrained_model_name_or_path (str or `transformers.PreTrainedModel`):\n                The identifier of the pretrained model to load or the pretrained model itself.\n            *model_args (sequence of positional arguments, *optional*):\n                All remaining positional arguments will be passed to the `_auto_model_parent_class`.\n            **kwargs (dict, *optional*):\n                Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class`\n                call (e.g. 
`transformers.AutoModelForCausalLM.from_pretrained`) and the specific\n                instance of the wrapped model.\n\n        NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.\n        \"\"\"\n        if kwargs is not None:\n            wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs)\n        else:\n            from_pretrained_kwargs = {}\n            wrapped_model_kwargs = {}\n\n        if isinstance(pretrained_model_name_or_path, str):\n            # Load the base model using the `transformers` AutoClass (e.g. AutoModelForCausalLM)\n            base_model = cls._auto_model_parent_class.from_pretrained(\n                pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs\n            )\n        elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):\n            base_model = pretrained_model_name_or_path\n        else:\n            raise ValueError(\n                f\"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}\"\n                \"Expected `str` or `transformers.PreTrainedModel`.\"\n            )\n\n        config = from_pretrained_kwargs.get(\"config\", None)\n        # if config is not None:\n        #     base_model.config.lora_rank = config.train.lora_rank\n        #     base_model.config.lora_alpha = config.train.lora_alpha\n        #     base_model.config.lora_train_bias = config.train.lora_train_bias\n\n        if isinstance(pretrained_model_name_or_path, str):\n            filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin\")\n            sharded_index_filename = os.path.join(pretrained_model_name_or_path, \"pytorch_model.bin.index.json\")\n            is_sharded = False\n\n            if not os.path.exists(filename):\n                try:\n                    filename = hf_hub_download(pretrained_model_name_or_path, \"pytorch_model.bin\")\n                # Sharded\n                except Exception:\n                    if os.path.exists(sharded_index_filename):\n                        index_file_name = sharded_index_filename\n                    else:\n                        index_file_name = hf_hub_download(\n                            pretrained_model_name_or_path,\n                            \"pytorch_model.bin.index.json\",\n                        )\n                    with open(index_file_name, \"r\") as f:\n                        index = json.load(f)\n                    # Collect files containing weights from supported modules\n                    files_to_download = set()\n                    for k, v in index[\"weight_map\"].items():\n                        if any([module in k for module in cls._supported_modules]):\n                            files_to_download.add(v)\n                    is_sharded = True\n\n            if is_sharded:\n                # Merge each shard into a state dict\n                # TODO: Optimize this to avoid wasting RAM\n                state_dict = {}\n                for shard_file in files_to_download:\n                    filename = os.path.join(pretrained_model_name_or_path, shard_file)\n                    # Download if shard file doesn't exist locally\n                    if not os.path.exists(filename):\n                        filename = hf_hub_download(pretrained_model_name_or_path, shard_file)\n                    state_dict.update(torch.load(filename, map_location=\"cpu\"))\n            else:\n                state_dict = torch.load(filename, 
map_location=\"cpu\")\n        else:\n            state_dict = pretrained_model_name_or_path.state_dict()\n\n        # Check if sft model is LoRA checkpoint, load the state dict into model\n        is_lora_checkpoint = False\n        for key in state_dict.keys():\n            if \"lora\" in key:\n                is_lora_checkpoint = True\n                break\n\n        if is_lora_checkpoint:\n            # base_model = SFTModelWithLoRA(base_model.config, base_model)\n            res = base_model.load_state_dict(state_dict, strict=False)\n\n        model = cls(base_model, **wrapped_model_kwargs)\n\n        model.post_init(state_dict=state_dict)\n\n        return model\n\n\nclass T5Branch(ModelBranch):\n    \"\"\"Decoder only T5 branch\"\"\"\n\n    def __init__(\n            self,\n            base_model: transformers.PreTrainedModel,\n            *,\n            num_layers_unfrozen: int,\n    ):\n        super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)\n        self.dropout = hf_get_decoder(base_model).dropout\n        self.is_decoder = True\n\n    def forward(  # noqa: max-complexity\n            self,\n            hidden_states: Optional[torch.LongTensor] = None,\n            attention_mask: Optional[torch.LongTensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, Seq2SeqLMOutputWithValue]:\n        \"\"\"Reference:\n        https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899  # noqa: E501\n        \"\"\"\n        batch_size, seq_length = hidden_states.shape[:2]\n        input_shape = (batch_size, seq_length)\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if attention_mask is None:\n            attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device)\n        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:\n            encoder_seq_length = encoder_hidden_states.shape[1]\n            encoder_attention_mask = torch.ones(\n                batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long\n            )\n\n        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)\n\n        if self.is_decoder and encoder_hidden_states is not None:\n            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n            if encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device)\n            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)\n        else:\n            encoder_extended_attention_mask = None\n\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n        position_bias = None\n        encoder_decoder_position_bias = None\n\n    
    for _, layer_module in enumerate(self.decoder_blocks):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_outputs = layer_module(\n                hidden_states,\n                attention_mask=extended_attention_mask,\n                position_bias=position_bias,\n                encoder_hidden_states=encoder_hidden_states,\n                encoder_attention_mask=encoder_extended_attention_mask,\n                encoder_decoder_position_bias=encoder_decoder_position_bias,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n            )\n\n            if use_cache is False:\n                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]\n\n            hidden_states, present_key_value_state = layer_outputs[:2]\n\n            position_bias = layer_outputs[2]\n            if self.is_decoder and encoder_hidden_states is not None:\n                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]\n\n            if output_attentions:\n                all_attentions = all_attentions + (layer_outputs[3],)\n\n        hidden_states = self.final_norm(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        sequence_output = hidden_states\n\n        if self.config.tie_word_embeddings:\n            # Rescale output before projecting on vocab\n            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586  # noqa: E501\n            sequence_output = sequence_output * (self.config.d_model**-0.5)\n\n        lm_logits = self.lm_head(sequence_output)\n\n        if not return_dict:\n            return (lm_logits,)\n\n        return Seq2SeqLMOutputWithValue(\n            logits=lm_logits,\n            decoder_hidden_states=all_hidden_states,\n            decoder_attentions=all_attentions,\n        )\n\n\n# Branch class utils\ndef hf_get_branch_class(\n        config: transformers.PretrainedConfig,\n) -> \"ModelBranch\":\n    \"\"\"Returns the model branch class for the given config.\"\"\"\n    gpt_branch_supported_archs = [\n        \"GPTJForCausalLM\",\n        \"GPT2LMHeadModel\",\n        \"GPTNeoForCausalLM\",\n        \"GPTNeoXForCausalLM\",\n    ]\n    opt_branch_supported_archs = [\"OPTForCausalLM\"]\n    bloom_branch_supported_archs = [\"BloomModel\", \"BloomForCausalLM\"]\n    arch = config.architectures[0]\n    if arch in gpt_branch_supported_archs:\n        return GPTModelBranch\n    elif arch in opt_branch_supported_archs:\n        return OPTModelBranch\n    elif arch in bloom_branch_supported_archs:\n        return BloomModelBranch\n    else:\n        all_supported_archs = sum(\n            [\n                gpt_branch_supported_archs,\n                opt_branch_supported_archs,\n                bloom_branch_supported_archs,\n            ],\n            [],\n        )\n        raise ValueError(\n            f\"Unsupported architecture: `{arch}`. The following architectures are \"\n            f\"available for model branching:\\n{all_supported_archs}\"\n        )\n"
  },
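  {
    "path": "examples/hydra_value_head_demo.py",
    "content": "\nimport torch\nimport torch.nn.functional as F\nimport transformers\n\n# NOTE: illustrative sketch only -- this file is not part of the original pipeline.\n# It shows how the hydra value-head wrapper above is intended to be used in PPO:\n# `forward` yields policy logits plus per-token values from `v_head`, while\n# `forward_hydra` re-runs only the frozen upper trunk to obtain reference logits\n# for the KL penalty, avoiding a second full copy of the model in memory.\n# Assumption: the wrapper classes are importable from src.models; the file name,\n# import path and \"gpt2\" checkpoint are placeholders -- adjust to your checkout.\nfrom src.models import AutoModelForCausalLMWithHydraValueHead\n\nbase = transformers.AutoModelForCausalLM.from_pretrained(\"gpt2\")\n# The wrapper reads LoRA settings off the config; plain HF configs lack them\nbase.config.lora_rank = 0\nmodel = AutoModelForCausalLMWithHydraValueHead.from_pretrained(base, num_layers_unfrozen=2)\nmodel.eval()\n\ninput_ids = torch.tensor([[1, 2, 3, 4]])\nattention_mask = torch.ones_like(input_ids)\n\nwith torch.no_grad():\n    # Policy pass: logits for the PPO objective, per-token values for the critic head\n    outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)\n    # Reference pass through the frozen branch (returns bare logits when return_dict=False)\n    ref_logits = model.forward_hydra(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)\n\n# Per-token KL(policy || ref), the penalty term added to the PPO reward\nlogprobs = F.log_softmax(outputs.logits, dim=-1)\nref_logprobs = F.log_softmax(ref_logits, dim=-1)\nkl = (logprobs.exp() * (logprobs - ref_logprobs)).sum(dim=-1)\nprint(outputs.value.shape, kl.shape)\n"
  },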
  {
    "path": "src/models/reward.py",
    "content": "\nimport torch\nfrom torch import nn\nfrom transformers.modeling_utils import PreTrainedModel\n\n\nclass RewardModel(PreTrainedModel):\n    supports_gradient_checkpointing = True\n\n    def __init__(self, config, model, tokenizer):\n        super().__init__(config)\n        self.config = config\n        self.model_type = config.model_type\n        self.pad_id = tokenizer.pad_token_id\n        self.transformer = model\n        # set the device of `v_head` the same as the device of the last component of `model`\n        st = model.state_dict()\n        device = st[list(st.keys())[-1]].device\n        self.v_head = nn.Linear(config.hidden_size, 1,\n                                bias=False,\n                                device=device)\n        # self.loss_fn = PairWiseLoss()\n\n    def gradient_checkpointing_enable(self):\n        self.transformer.gradient_checkpointing_enable()\n\n    def gradient_checkpointing_disable(self):\n        self.transformer.gradient_checkpointing_disable()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, PreTrainedModel):\n            module.gradient_checkpointing = value\n\n    def reward(\n            self,\n            input_ids=None,\n            attention_mask=None,\n            position_ids=None,\n            use_cache=None,\n    ):\n        batch_size = input_ids.shape[0]\n        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask, position_ids=position_ids,\n                                               use_cache=use_cache)\n        if self.model_type == \"glm\":\n            hidden_states = transformer_outputs.mems[-1]\n        elif self.model_type == \"chatglm\":\n            hidden_states = transformer_outputs[0].permute(1, 0, 2)\n        elif self.model_type == \"pangu\":\n            hidden_states = transformer_outputs[0]\n            hidden_states = hidden_states.squeeze(1)\n        else:\n            hidden_states = transformer_outputs[0]\n\n        assert len(hidden_states.shape) == 3\n\n        values = self.v_head(hidden_states).squeeze(-1)\n\n        # outputs = self.body(sequences, attention_mask=attention_mask)\n        # last_hidden_states = outputs['last_hidden_state']\n        # values = self.value_head(last_hidden_states)[:, :-1]\n\n        rewards = None\n        # rewards = values.mean(dim=-1)\n        # if len(rewards.shape) == 2:\n        #     rewards = rewards.squeeze(1)    # ensure shape is (B)\n        #\n        # assert len(rewards.shape) == 1 and rewards.shape[0] == batch_size\n\n        return values, rewards\n\n    def forward(\n            self,\n            chosen_input_ids,\n            chosen_attention_mask=None,\n            chosen_position_ids=None,\n            rejected_input_ids=None,\n            rejected_attention_mask=None,\n            rejected_position_ids=None,\n            use_cache=None,\n            **kwargs\n    ):\n        bs = chosen_input_ids.shape[0]\n        chosen_end_scores = []\n        rejected_end_scores = []\n\n        # compute reward for chosen inputs\n        chosen_values, chosen_reward = self.reward(chosen_input_ids, attention_mask=chosen_attention_mask,\n                                                   position_ids=chosen_position_ids, use_cache=use_cache)\n        if len(chosen_input_ids.shape) == 3:\n            chosen_input_ids = chosen_input_ids.squeeze(1)\n\n        # compute reward for rejected inputs if it is not none\n        if rejected_input_ids is not None:\n            
reject_values, reject_reward = self.reward(rejected_input_ids, attention_mask=rejected_attention_mask,\n                                                       position_ids=rejected_position_ids, use_cache=use_cache)\n            # loss = self.loss_fn(chosen_reward, reject_reward)\n\n            if len(rejected_input_ids.shape) == 3:\n                rejected_input_ids = rejected_input_ids.squeeze(1)\n            loss = 0\n            for i in range(bs):\n                # Check if there is any padding otherwise take length of sequence\n                c_inds = (chosen_input_ids[i] == self.pad_id).nonzero()\n                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n                r_inds = (rejected_input_ids[i] == self.pad_id).nonzero()\n                r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected_input_ids.shape[1]\n                end_ind = max(c_ind, r_ind)\n\n                # Retrieve first index where trajectories diverge\n                divergence_ind = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()[0]\n                assert divergence_ind > 0\n\n                # Index into the correct rewards\n                c_truncated_reward = chosen_values[i][divergence_ind:end_ind]\n                r_truncated_reward = reject_values[i][divergence_ind:end_ind]\n\n                # Use the last non-padding token output as reward score\n                chosen_end_scores.append(chosen_values[i][c_ind-1])\n                rejected_end_scores.append(reject_values[i][r_ind-1])\n\n                # Compute loss\n                loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()\n            loss = loss / bs\n            rejected_end_scores = torch.stack(rejected_end_scores)\n        else:\n            reject_values = None\n            rejected_end_scores = None\n            loss = None\n            for i in range(bs):\n                c_inds = (chosen_input_ids[i] == self.pad_id).nonzero()\n                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen_input_ids.shape[1]\n                chosen_end_scores.append(chosen_values[i, c_ind - 1])\n\n        chosen_end_scores = torch.stack(chosen_end_scores)\n\n        return {\n            \"loss\": loss,\n            \"chosen_values\": chosen_values,\n            \"chosen_reward\": chosen_end_scores,\n            \"reject_values\": reject_values,\n            \"reject_reward\": rejected_end_scores\n        }\n"
  },
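  {
    "path": "examples/reward_loss_demo.py",
    "content": "\nimport torch\n\n# NOTE: illustrative sketch only -- this file is not part of the original pipeline.\n# It replays, on toy tensors, the pairwise ranking loss that RewardModel.forward\n# computes above: per-token values are compared from the first token where the\n# chosen and rejected sequences diverge up to the last non-padding token, and the\n# loss is -log(sigmoid(v_chosen - v_rejected)) averaged over that span and batch.\n\nchosen_values = torch.tensor([[0.2, 0.9, 1.1, 0.0]])   # v_head outputs, chosen sequence\nreject_values = torch.tensor([[0.2, 0.1, -0.3, 0.0]])  # v_head outputs, rejected sequence\ndivergence_ind, end_ind = 1, 3  # first differing token; max of the two last non-pad indices\n\nc_truncated_reward = chosen_values[0, divergence_ind:end_ind]\nr_truncated_reward = reject_values[0, divergence_ind:end_ind]\nloss = -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()\n\n# The scalar reward reported for a sequence is its value at the last non-padding token\nchosen_end_score = chosen_values[0, end_ind - 1]\nprint(loss.item(), chosen_end_score.item())\n"
  },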
  {
    "path": "src/models/rlhf_engine.py",
    "content": "# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport time\nimport torch\nimport deepspeed\nimport math\nimport json\n\nfrom deepspeed.ops.adam import FusedAdam\nfrom deepspeed.ops.adam import DeepSpeedCPUAdam\nfrom transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, get_scheduler\nfrom transformers import AutoConfig, AutoModel\nfrom transformers.deepspeed import HfDeepSpeedConfig\nfrom peft import (\n    prepare_model_for_kbit_training,\n    LoraConfig,\n    get_peft_model\n)\n\nfrom src.utils.config import get_train_ds_config, get_eval_ds_config\n# from utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters\nfrom src.models import RewardModel\nfrom src.utils.logger import logger, RESOURCE_PATH\nfrom src.utils.modeling_utils import get_optimizer_grouped_parameters\n\"\"\"\nTODOs:\n  * support HF models for critic (for debugging), must be a previously saved ckpt from step-2\n  * determine ds_config/zero_stage based on model size, gpu style, world size, etc\n    - get model size by creating simple meta model\n    - 1.3b: zero-2 for actor/ref models, zero-0 for others\n    - 13b+: zero-3 for all models\n\"\"\"\n\n\ndef log_init(model_name, rank, stime=None):\n    if rank == 0:\n        tag = \"start\" if stime is None else \"end\"\n        suffix = \"ing\" if stime is None else \"ed\"\n        duration = \"\"\n        if stime is not None:\n            duration = \"(duration: {:.2f}s)\".format(time.time() - stime)\n        logger.info(f\"[{tag}] Initializ{suffix} {model_name} Model [{tag}] {duration}\")\n\n        return time.time()\n\n\ndef create_hf_model(model_class,\n                    model_name_or_path,\n                    tokenizer,\n                    ds_config=None,\n                    rlhf_training=False,\n                    disable_dropout=False):\n    model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)\n    if disable_dropout:\n        model_config.dropout = 0.0\n    # Note: dschf is defined in function scope to avoid global effects\n    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration\n    if ds_config is not None and ds_config[\"zero_optimization\"][\"stage\"] == 3:\n        dschf = HfDeepSpeedConfig(ds_config)\n    else:\n        dschf = None\n    if rlhf_training:\n        # the weight loading is handled by create critic model\n        model = model_class.from_config(model_config, trust_remote_code=True)\n    else:\n        model = model_class.from_pretrained(\n            model_name_or_path,\n            from_tf=bool(\".ckpt\" in model_name_or_path),\n            config=model_config, trust_remote_code=True)\n\n    model.config.end_token_id = tokenizer.eos_token_id\n    # model.config.pad_token_id = model.config.eos_token_id\n    # model.resize_token_embeddings(int(\n    #     8 *\n    #     math.ceil(len(tokenizer) / 8.0)))  # make the vocab size multiple of 8\n    if \"pangu\" in model_name_or_path:\n        model.config.target_modules = \"q_proj,k_proj,v_proj\"\n        model.config.task_type = \"CAUSAL_LM\"\n    elif \"chatglm\" in model_name_or_path:\n        model.config.target_modules = \"query_key_value\"\n        model.config.task_type = \"SEQ_2_SEQ_LM\"\n    elif \"glm\" in model_name_or_path:\n        model.config.target_modules = \"query_key_value\"\n        model.config.task_type = \"SEQ_2_SEQ_LM\"\n    else:\n        raise ValueError(f\"Unsupported model type: 
{model_name_or_path}\")\n\n    return model\n\n\ndef create_critic_model(model_name_or_path,\n                        tokenizer,\n                        ds_config,\n                        num_padding_at_beginning=0,\n                        rlhf_training=False,\n                        disable_dropout=False,\n                        checkpoint=None,\n                        lora_rank=0,\n                        lora_alpha=1,\n                        lora_train_bias=\"none\"):\n    # OPT model family always put a padding token at the beginning of the sequence,\n    # we did not see this in other models but not sure if it is a general rule\n    if \"pangu\" in model_name_or_path:\n        model_class = AutoModelForCausalLM\n    elif \"glm\" in model_name_or_path:\n        model_class = AutoModelForSeq2SeqLM\n    else:\n        raise ValueError(f\"Unsupported model type: {model_name_or_path}\")\n    critic_model = create_hf_model(model_class, model_name_or_path, tokenizer,\n                                   ds_config, rlhf_training, disable_dropout)\n    critic_model.config.lora_rank = lora_rank\n    critic_model.config.lora_alpha = lora_alpha\n    critic_model.config.lora_train_bias = lora_train_bias\n    if \"pangu\" in model_name_or_path or \"chatglm\" in model_name_or_path:\n        critic_model = RewardModel(critic_model.config, critic_model.transformer, tokenizer,\n                                   # num_padding_at_beginning=num_padding_at_beginning\n                                   )\n    else:\n        critic_model = RewardModel(critic_model.config, critic_model.glm, tokenizer)\n\n    if rlhf_training:\n        assert os.path.exists(checkpoint), f\"Cannot find reward model checkpoint at {checkpoint}\"\n        critic_model.load_state_dict(torch.load(checkpoint, map_location='cpu'))\n\n    return critic_model\n\n\nclass DeepSpeedRLHFEngine:\n\n    def __init__(self, actor_model_name_or_path, critic_model_name_or_path,\n                 tokenizer, args, num_total_iters):\n        self.args = args\n        self.num_total_iters = num_total_iters\n        self.tokenizer = tokenizer\n        if \"pangu\" in actor_model_name_or_path:\n            self.model_class = AutoModelForCausalLM\n        elif \"glm\" in actor_model_name_or_path:\n            self.model_class = AutoModelForSeq2SeqLM\n        else:\n            raise ValueError(f\"Unsuppported model type: {actor_model_name_or_path}\")\n\n        self.actor = self._init_actor(\n            actor_model_name_or_path=actor_model_name_or_path)\n        self.ref = self._init_ref(\n            actor_model_name_or_path=actor_model_name_or_path)\n        self.actor_ema = None\n        if self.args.enable_ema:\n            self.actor_ema = self._init_ema(\n                actor_model_name_or_path=actor_model_name_or_path)\n\n        self.reward = self._init_reward(\n            critic_model_name_or_path=critic_model_name_or_path)\n        if args.enable_joint_optimization:\n            self.critic = self._init_critic(\n                critic_model_name_or_path=critic_model_name_or_path)\n\n            if self.args.critic_gradient_checkpointing:\n                self.critic.gradient_checkpointing_enable()\n        else:\n            self.critic = None\n\n    def _init_actor(self, actor_model_name_or_path):\n        stime = log_init(\"Actor\", self.args.local_rank)\n\n        # DS Config\n        ds_config = get_train_ds_config(\n            global_batch_size=self.args.global_train_batch_size_actor,\n            
micro_batch_size=self.args.ppo_train_batch_size,\n            gradient_accumulation_steps=self.args.gradient_accumulation_steps_actor,\n            offload=self.args.offload,\n            stage=self.args.actor_zero_stage,\n            enable_hybrid_engine=self.args.enable_hybrid_engine,\n            inference_tp_size=self.args.inference_tp_size,\n            release_inference_cache=self.args.release_inference_cache,\n            pin_parameters=(not self.args.unpin_actor_parameters),\n            tp_gather_partition_size=self.args.tp_gather_partition_size,\n            max_out_tokens=self.args.max_length)\n\n        # Model\n        actor_model = create_hf_model(\n            model_class=self.model_class,\n            model_name_or_path=actor_model_name_or_path,\n            tokenizer=self.tokenizer,\n            ds_config=ds_config,\n            disable_dropout=self.args.disable_actor_dropout)\n\n        # LoRA\n        if self.args.actor_lora_rank > 0:\n            config = LoraConfig(\n                r=self.args.actor_lora_rank,\n                lora_alpha=self.args.actor_lora_alpha,\n                target_modules=actor_model.config.target_modules.split(\",\"),\n                lora_dropout=0.05,\n                task_type=actor_model.config.task_type\n            )\n            actor_model.enable_input_require_grads()\n            actor_model = get_peft_model(actor_model, config)\n            # actor_model = convert_linear_layer_to_lora(\n            #     actor_model, self.args.actor_lora_module_name,\n            #     self.args.actor_lora_rank)\n            # if self.args.only_optimize_lora:\n            #     actor_model = only_optimize_lora_parameters(actor_model)\n\n        # Optimizer\n        AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam\n        optim_params = get_optimizer_grouped_parameters(\n            actor_model, self.args.actor_weight_decay)\n        optim = AdamOptimizer(optim_params,\n                              lr=self.args.actor_learning_rate,\n                              betas=(0.9, 0.95))\n\n        # LR Scheduler\n        lr_scheduler = get_scheduler(\n            name=self.args.lr_scheduler_type,\n            optimizer=optim,\n            num_warmup_steps=self.args.warmup_steps,\n            num_training_steps=self.num_total_iters,\n        )\n\n        # DeepSpeed Engine\n        actor_engine, *_ = deepspeed.initialize(model=actor_model,\n                                                optimizer=optim,\n                                                lr_scheduler=lr_scheduler,\n                                                config=ds_config)\n        actor_engine.config['pad_token_id'] = actor_model.config.pad_token_id\n        log_init(\"Actor\", self.args.local_rank, stime=stime)\n\n        return actor_engine\n\n    def _init_ref(self, actor_model_name_or_path):\n        stime = log_init(\"Ref\", self.args.local_rank)\n        # DS Config\n        zero_stage = self.args.actor_zero_stage\n        if zero_stage != 3:\n            # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model\n            zero_stage = 0\n        ds_config = get_eval_ds_config(global_batch_size=self.args.global_train_batch_size_actor,\n                                       micro_batch_size=self.args.ppo_train_batch_size,\n                                       gradient_accumulation_steps=self.args.eval_accumulation_steps,\n                                       offload=self.args.offload_reference_model,\n    
                                   stage=zero_stage)\n\n        ref_model = create_hf_model(self.model_class,\n                                    actor_model_name_or_path, self.tokenizer,\n                                    ds_config)\n\n        ref_engine, *_ = deepspeed.initialize(model=ref_model,\n                                              config=ds_config)\n\n        log_init(\"Ref\", self.args.local_rank, stime=stime)\n        return ref_engine\n\n    def _init_ema(self, actor_model_name_or_path):\n        stime = log_init(\"EMA\", self.args.local_rank)\n        # DS Config\n        zero_stage = self.args.actor_zero_stage\n        if zero_stage != 3:\n            # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory\n            zero_stage = 0\n        ds_config = get_eval_ds_config(global_batch_size=self.args.global_train_batch_size_actor,\n                                       micro_batch_size=self.args.ppo_train_batch_size,\n                                       gradient_accumulation_steps=self.args.eval_accumulation_steps,\n                                       offload=self.args.offload_reference_model,\n                                       stage=zero_stage)\n\n        actor_model_ema = create_hf_model(self.model_class,\n                                          actor_model_name_or_path,\n                                          self.tokenizer, ds_config)\n        if self.args.actor_lora_rank > 0:\n            config = LoraConfig(\n                r=self.args.actor_lora_rank,\n                lora_alpha=self.args.actor_lora_alpha,\n                target_modules=actor_model_ema.config.target_modules.split(\",\"),\n                lora_dropout=0.05,\n                task_type=actor_model_ema.config.task_type\n            )\n            actor_model_ema.enable_input_require_grads()\n            actor_model_ema = get_peft_model(actor_model_ema, config)\n            # actor_model_ema = convert_linear_layer_to_lora(\n            #     actor_model_ema, self.args.actor_lora_module_name,\n            #     self.args.actor_lora_rank)\n\n        ema_engine, *_ = deepspeed.initialize(model=actor_model_ema,\n                                              config=ds_config)\n\n        log_init(\"EMA\", self.args.local_rank, stime=stime)\n        return ema_engine\n\n    def _init_critic(self, critic_model_name_or_path):\n        stime = log_init(\"Critic\", self.args.local_rank)\n        ds_config = get_train_ds_config(global_batch_size=self.args.global_train_batch_size_critic,\n                                        micro_batch_size=self.args.ppo_train_batch_size,\n                                        gradient_accumulation_steps=self.args.gradient_accumulation_steps,\n                                        offload=self.args.offload,\n                                        stage=self.args.critic_zero_stage)\n\n        #TODO(jeff): should not be needed, we should be able to use ds_config above\n        #TODO(jeff): it means we never create the critic w. 
zero.init context if we are using ZeRO-3\n        ds_eval_config = get_eval_ds_config(global_batch_size=self.args.global_train_batch_size_critic,\n                                            micro_batch_size=self.args.ppo_train_batch_size,\n                                            gradient_accumulation_steps=self.args.eval_accumulation_steps,\n                                            offload=False,\n                                            stage=0)\n\n        # Model\n        critic_model = create_critic_model(\n            model_name_or_path=critic_model_name_or_path,\n            tokenizer=self.tokenizer,\n            ds_config=ds_eval_config,\n            # num_padding_at_beginning=self.args.num_padding_at_beginning,\n            rlhf_training=True,\n            disable_dropout=self.args.disable_critic_dropout,\n            checkpoint=self.args.critic_checkpoint,\n            lora_rank=self.args.critic_lora_rank,\n            lora_alpha=self.args.lora_alpha,\n            lora_train_bias=self.args.lora_train_bias)\n\n        # LoRA\n        if self.args.critic_lora_rank > 0:\n            config = LoraConfig(\n                r=self.args.critic_lora_rank,\n                lora_alpha=self.args.critic_lora_alpha,\n                target_modules=critic_model.config.target_modules.split(\",\"),\n                lora_dropout=0.05,\n                task_type=critic_model.config.task_type\n            )\n            critic_model.enable_input_require_grads()\n            critic_model = get_peft_model(critic_model, config)\n            # critic_model = convert_linear_layer_to_lora(\n            #     critic_model, self.args.critic_lora_module_name,\n            #     self.args.critic_lora_rank)\n            # if self.args.only_optimize_lora:\n            #     critic_model = only_optimize_lora_parameters(critic_model)\n\n        # Optimizer\n        AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam\n        optim_params = get_optimizer_grouped_parameters(\n            critic_model, self.args.critic_weight_decay)\n        optim = AdamOptimizer(optim_params,\n                              lr=self.args.critic_learning_rate,\n                              betas=(0.9, 0.95))\n\n        # LR Scheduler\n        lr_scheduler = get_scheduler(\n            name=self.args.lr_scheduler_type,\n            optimizer=optim,\n            num_warmup_steps=self.args.warmup_steps,\n            num_training_steps=self.num_total_iters,\n        )\n\n        # DeepSpeed Engine\n        critic_engine, *_ = deepspeed.initialize(model=critic_model,\n                                                 optimizer=optim,\n                                                 lr_scheduler=lr_scheduler,\n                                                 config=ds_config)\n\n        log_init(\"Critic\", self.args.local_rank, stime=stime)\n        return critic_engine\n\n    def _init_reward(self, critic_model_name_or_path):\n        stime = log_init(\"Reward\", self.args.local_rank)\n        # DS Config\n        zero_stage = self.args.critic_zero_stage\n        if zero_stage != 3:\n            # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory\n            zero_stage = 0\n        ds_config = get_eval_ds_config(global_batch_size=self.args.global_train_batch_size_critic,\n                                       micro_batch_size=self.args.ppo_train_batch_size,\n                                       gradient_accumulation_steps=self.args.eval_accumulation_steps,\n          
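                             # the reward model is used for scoring only here: deepspeed.initialize below\n                                       # receives no optimizer, so this engine never updates its weights\n          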
                             offload=self.args.offload,\n                                       stage=zero_stage)\n\n        #TODO(jeff): should not be needed, we should be able to use ds_config above\n        #TODO(jeff): it means we never create the critic w. zero.init context if we are using ZeRO-3\n        ds_eval_config = get_eval_ds_config(global_batch_size=self.args.global_train_batch_size_critic,\n                                            micro_batch_size=self.args.ppo_train_batch_size,\n                                            gradient_accumulation_steps=self.args.eval_accumulation_steps,\n                                            offload=False,\n                                            stage=0)\n\n        # Model\n        reward_model = create_critic_model(\n            model_name_or_path=critic_model_name_or_path,\n            tokenizer=self.tokenizer,\n            ds_config=ds_eval_config,\n            # num_padding_at_beginning=self.args.num_padding_at_beginning,\n            rlhf_training=True,\n            checkpoint=self.args.critic_checkpoint,\n            lora_rank=self.args.critic_lora_rank,\n            lora_alpha=self.args.lora_alpha,\n            lora_train_bias=self.args.lora_train_bias)\n\n        reward_engine, *_ = deepspeed.initialize(model=reward_model,\n                                                 config=ds_config)\n\n        log_init(\"Reward\", self.args.local_rank, stime=stime)\n        return reward_engine\n"
  },
  {
    "path": "src/models/trainer.py",
    "content": "from collections import defaultdict\n\nimport deepspeed\nimport sys\nimport json\nimport os\nimport ray\nimport torch\nimport logging\nimport uuid\nimport torch.nn.functional as F\n\nfrom abc import abstractmethod\n\nfrom datasets import Dataset\nfrom deepspeed.runtime.zero import ZeroParamStatus\nfrom time import time\nfrom torch import nn\nfrom tqdm import tqdm\nfrom typing import Any, Callable, Iterable, Dict, List, Optional, Tuple, Union, Literal\nfrom torch.utils.data import DataLoader\nfrom accelerate import Accelerator  # type: ignore\nfrom ray.air import session\ntry:\n    from ray.train import Checkpoint\nexcept ImportError:\n    from ray.air.checkpoint import Checkpoint\nfrom rich.console import Console\nfrom rich.table import Table\nfrom transformers import (\n    AutoTokenizer,\n    PretrainedConfig,\n    Trainer,\n    PreTrainedModel,\n    TrainingArguments,\n    DataCollator,\n    PreTrainedTokenizerBase,\n    TrainerCallback,\n)\nfrom trl.models import create_reference_model\nfrom trl.trainer.utils import disable_dropout_in_model, pad_to_length\n\nfrom src.utils.logger import logger\nfrom src.utils.config import TRLConfig\nfrom src.data.pipeline import BaseRolloutStore\nfrom src.utils.file_utils import significant, print_gpu_utilization, print_gpu_utilization_torch\nfrom src.utils.modeling_utils import (\n    filter_non_scalars,\n    get_distributed_config,\n    get_git_tag,\n    get_optimizer_class,\n    get_scheduler_class,\n    flatten_dict,\n    freeze_bottom_causal_layers,\n    freeze_bottom_seq2seq_layers,\n    get_delta_model_class,\n    parse_delta_kwargs,\n)\n\nfrom src.data.data_types import PromptBatch, PPORLBatch, PPORLElement\nfrom src.models.ppo import (\n    AdaptiveKLController,\n    AutoModelForCausalLMWithHydraValueHead,\n    AutoModelForSeq2SeqLMWithHydraValueHead,\n    FixedKLController,\n)\nfrom src.data.pipeline import BasePipeline, PPORolloutStorage\nfrom src.utils.modeling_utils import Clock, RunningMoments, logprobs_of_labels\nfrom src.utils.logger import logger\n\n# specifies a dictionary of architectures\n_TRAINERS: Dict[str, Any] = {}  # registry\n\n\ndef register_trainer(name):\n    \"\"\"Decorator used to register a trainer\n    Args:\n        name: Name of the trainer type to register\n    \"\"\"\n\n    def register_class(cls, name):\n        _TRAINERS[name] = cls\n        setattr(sys.modules[__name__], name, cls)\n        return cls\n\n    if isinstance(name, str):\n        name = name.lower()\n        return lambda c: register_class(c, name)\n\n    cls = name\n    name = cls.__name__\n    register_class(cls, name.lower())\n\n    return cls\n\n\n@register_trainer\nclass BaseRLTrainer:\n    def __init__(\n            self,\n            config: TRLConfig,\n            reward_fn=None,\n            metric_fn=None,\n            logit_mask=None,\n            stop_sequences=None,\n            train_mode=False,\n    ):\n        self.store: BaseRolloutStore = None\n        self.config = config\n        self.reward_fn = reward_fn\n        self.metric_fn = metric_fn\n        self.train_mode = train_mode\n        self.logit_mask = logit_mask\n        self.stop_sequences = stop_sequences\n\n    def push_to_store(self, data):\n        self.store.push(data)\n\n    def add_eval_pipeline(self, eval_pipeline):\n        \"\"\"Adds pipeline for validation prompts\"\"\"\n        self.eval_pipeline = eval_pipeline\n\n    @abstractmethod\n    def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]:\n        \"\"\"\n     
   Sample from the language model. Takes prompts and a maximum length to generate.\n\n        :param prompts: List of prompts to tokenize and use as context\n\n        :param length: How many new tokens to generate for each prompt\n        :type length: int\n\n        :param n_samples: Number of samples to generate; defaults to the number of prompts\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def learn(\n            self,\n            log_fn: Callable = None,\n            save_fn: Callable = None,\n            eval_fn: Callable = None,\n    ):\n        \"\"\"\n        Use experiences in RolloutStore to learn\n\n        :param log_fn: Optional function that is called when logging and passed a dict of logging-relevant values\n        :type log_fn: Callable[Dict[str, any]]\n\n        :param save_fn: Optional function to call after saving. Is passed the components.\n        :type save_fn: Callable[Dict[str, any]]\n\n        :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this.\n        :type eval_fn: Callable[BaseRLTrainer]\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def save(self, directory: Optional[str] = None):\n        \"\"\"Creates a checkpoint of training states\"\"\"\n        pass\n\n    @abstractmethod\n    def load(self, directory=None):\n        \"\"\"Loads a checkpoint created from `save`\"\"\"\n        pass\n\n\n@register_trainer\nclass AccelerateRLTrainer(BaseRLTrainer):\n    \"\"\"\n    RL model trainer with an `accelerate` based backend\n    \"\"\"\n\n    def __init__(self, config, **kwargs):  # noqa: C901\n        super().__init__(config, **kwargs)\n        self.max_length = config.train.seq_length\n        self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)\n\n        if self.accelerator.state.deepspeed_plugin is not None:\n            # by accelerate's default, arguments in `model.forward` would be cast to half\n            if \"fp16\" in self.accelerator.state.deepspeed_plugin.deepspeed_config:\n                self.accelerator.state.deepspeed_plugin.deepspeed_config[\"fp16\"][\"auto_cast\"] = False\n\n        if int(os.environ.get(\"WORLD_SIZE\", 1)) > 1:\n            torch.distributed.barrier(device_ids=[int(os.environ.get(\"LOCAL_RANK\", 0))])\n\n        self.model = self.setup_model()\n        self.opt = self.setup_optimizer()\n        self.scheduler = self.setup_scheduler()\n\n        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path, trust_remote_code=True)\n        self.tokenizer.padding_side = config.tokenizer.padding_side\n        self.tokenizer.truncation_side = config.tokenizer.truncation_side\n        self.padding_side = config.tokenizer.padding_side\n        # self.tokenizer.sep_token = \"<sep>\"\n        # if config.model.model_arch_type != \"seq2seq\":\n        #     self.tokenizer.pad_token = self.tokenizer.eos_token\n\n        script_name = os.path.basename(sys.argv[0]).rsplit(\".\", 1)[0]\n        if not isinstance(config.model.model_path, str):\n            model_name = str(config.model.model_path).split()[0]\n        else:\n            model_name = config.model.model_path.split(\"/\")[-1]\n\n        if self.accelerator.num_processes == 1:\n            num_gpus = \"1gpu\"\n        else:\n            num_gpus = f\"{self.accelerator.num_processes}gpus\"\n        branch = get_git_tag()[0]\n\n        run_name = \"/\".join([script_name, model_name, num_gpus]) + f\":{branch}\"\n\n        if self.accelerator.is_main_process and not 
ray.is_initialized():\n            config_dict = self.config.to_dict()\n            dist_config = get_distributed_config(self.accelerator)\n            config_dict[\"distributed\"] = dist_config\n            init_trackers_kwargs = {}\n\n            if config.train.tracker == \"wandb\":\n                init_trackers_kwargs[\"wandb\"] = {\n                    \"name\": run_name,\n                    \"entity\": self.config.train.entity_name,\n                    \"group\": self.config.train.group_name,\n                    \"tags\": [\"/\".join(get_git_tag())],\n                    \"mode\": \"disabled\" if os.environ.get(\"debug\", False) else \"online\",\n                }\n\n                self.accelerator.init_trackers(\n                    project_name=self.config.train.project_name,\n                    config=config_dict,\n                    init_kwargs=init_trackers_kwargs,\n                )\n            elif config.train.tracker == \"tensorboard\":\n                # flatten the config for tensorboard and split list-valued hparams into scalar entries\n                config_dict_flat = flatten_dict(config_dict)\n                config_dict_flat[\"optimizer/kwargs/beta_1\"] = config_dict_flat[\"optimizer/kwargs/betas\"][0]\n                config_dict_flat[\"optimizer/kwargs/beta_2\"] = config_dict_flat[\"optimizer/kwargs/betas\"][1]\n                config_dict_flat.pop(\"optimizer/kwargs/betas\", None)\n                self.accelerator.init_trackers(\n                    project_name=self.config.train.project_name,\n                    config=config_dict_flat,\n                )\n            elif config.train.tracker is None:\n                self.accelerator.init_trackers(project_name=self.config.train.project_name)\n            else:\n                raise ValueError(\n                    f\"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`. 
\"\n                    \"Set `tracker` to `None` to disable tracking.\"\n                )\n\n    def setup_model(self):\n        \"\"\"\n        Returns a model derived from an instance's TRLConfig\n        \"\"\"\n        logger.info(f\"Initializing model: {self.config.model.model_path}\")\n\n        # Retrieves model equipped for ppo, ilql, etc\n        model = self.get_arch(self.config)\n        # if self.config.model.model_arch_type == \"seq2seq\":\n        #     freeze_bottom_seq2seq_layers(model.base_model, self.config.model.num_layers_unfrozen)\n        # else:\n        freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen)\n        # Set the delta tuning strategies\n        if self.config.model.delta_kwargs is not None:\n            delta_type, delta_kwargs = parse_delta_kwargs(\n                model.base_model.config,\n                self.config.model.delta_kwargs,\n                self.config.model.num_layers_unfrozen,\n            )\n            delta_model_class = get_delta_model_class(delta_type)\n            delta_model = delta_model_class(model.base_model, **delta_kwargs)\n            delta_model.freeze_module(exclude=[\"deltas\"], set_state_dict=True)\n            if self.accelerator.is_main_process:\n                delta_model.log()\n        return model\n\n    def setup_optimizer(self):\n        \"\"\"\n        Returns an optimizer derived from an instance's TRLConfig\n        \"\"\"\n        optimizer_class = get_optimizer_class(self.config.optimizer.name)\n        optimizer = optimizer_class(\n            self.model.parameters(),\n            **self.config.optimizer.kwargs,\n        )\n\n        if \"bitsandbytes\" in optimizer.__class__.__module__:\n            # Force 32-bit `nn.Embedding` weights for stability. 
See discussion:\n            # https://github.com/huggingface/transformers/issues/14819#issuecomment-1016017746\n            from bitsandbytes.optim import GlobalOptimManager\n\n            manager = GlobalOptimManager.get_instance()\n            for module in self.model.modules():\n                if isinstance(module, torch.nn.Embedding):\n                    manager.register_module_override(module, \"weight\", {\"optim_bits\": 32})\n\n        return optimizer\n\n    def setup_scheduler(self):\n        \"\"\"\n        Returns a learning rate scheduler derived from an instance's TRLConfig\n        \"\"\"\n        scheduler_class = get_scheduler_class(self.config.scheduler.name)\n        scheduler = scheduler_class(self.opt, **self.config.scheduler.kwargs)\n        return scheduler\n\n    def decode(\n            self,\n            prompts: List[torch.LongTensor],\n            samples: List[torch.LongTensor],\n            prompt_sizes: torch.LongTensor = None,\n    ) -> Tuple[List[str], List[str], List[str], List[str], List[List[torch.Tensor]]]:\n        \"\"\"\n        Decode tensor generations into lists of strings (`samples`, `prompts`, `outputs`, `prefixes`) plus the raw output token tensors\n        \"\"\"\n        # Assuming prompts were left-padded\n        prompt_sizes = []\n        prefix_indices = []\n        for prompt in prompts:\n            prefix_idx = None\n            if \"chatglm\" in self.config.model.model_path:\n                prompt_sizes.append(len(prompt))\n            else:\n                logger.debug(f\"[decode] prompt: {prompt}\")\n                if isinstance(prompt, torch.Tensor):\n                    prompt = prompt.cpu().detach().tolist()\n                prompt_sizes.append(prompt.index(self.tokenizer.sep_token_id))\n                if \"glm\" in self.config.model.model_path:\n                    try:\n                        prefix_idx = prompt.index(self.tokenizer.mask_token_id)\n                    except ValueError:\n                        pass\n            prefix_indices.append(prefix_idx)\n\n        str_samples, str_prompts, str_outputs, str_prefixes, sample_outputs = [], [], [], [], []\n        for prompt, sample, prompt_size, prefix_idx in zip(prompts, samples, prompt_sizes, prefix_indices):\n            # if self.config.model.model_arch_type == \"seq2seq\":\n            #     output_start_ix = 0\n            # else:\n            output_start_ix = prompt_size\n\n            str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True)\n            if prefix_idx is not None:\n                str_prefix = self.tokenizer.decode(sample[output_start_ix:prefix_idx], skip_special_tokens=True)\n                sample_output = sample[prefix_idx:]\n                str_output = self.tokenizer.decode(sample_output, skip_special_tokens=True)\n            else:\n                str_prefix = None\n                sample_output = sample[output_start_ix:]\n                str_output = self.tokenizer.decode(sample_output, skip_special_tokens=True)\n\n            # Trim outputs up to `self.stop_sequences` if any are present\n            if self.stop_sequences:\n                for stop in self.stop_sequences:\n                    stop_ix = str_output.find(stop)\n                    if stop_ix >= 0:\n                        str_output = str_output[:stop_ix].rstrip()\n\n            str_prompts.append(str_prompt)\n            str_outputs.append(str_output)\n            str_prefixes.append(str_prefix)\n            sample_outputs.append(sample_output)\n\n       
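     # Re-join prompt and output into the full sample: ChatGLM concatenates them\n            # directly, other models re-insert the sep token in between\n       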
     if \"chatglm\" in self.config.model.model_path:\n                sample = str_prompt + str_output\n            else:\n                sample = str_prompt + self.tokenizer.sep_token + str_output\n\n            str_samples.append(sample)\n\n        return str_samples, str_prompts, str_outputs, str_prefixes, sample_outputs\n\n    def generate(self, input_ids, attention_mask=None, **kwargs):\n        \"\"\"Wraps hf's `generate` adding some specific method's defaults\"\"\"\n        input_ids = input_ids.to(self.accelerator.device)\n        if attention_mask is not None:\n            attention_mask = attention_mask.to(self.accelerator.device)\n        if self.generate_experience_kwargs is not None:\n            kwargs = dict(self.generate_experience_kwargs, **kwargs)\n        else:\n            kwargs = dict(self.generate_kwargs, **kwargs)\n\n        with torch.no_grad():\n            return self.accelerator.unwrap_model(self.model).generate(\n                input_ids=input_ids, attention_mask=attention_mask, **kwargs\n            )\n\n    def generate_eval(self, input_ids, attention_mask=None, **kwargs):\n        \"\"\"Wraps hf's `generate` adding some specific method's defaults\"\"\"\n        input_ids = input_ids.to(self.accelerator.device)\n        if attention_mask is not None:\n            attention_mask = attention_mask.to(self.accelerator.device)\n\n        kwargs = dict(self.generate_kwargs, **kwargs)\n\n        with torch.no_grad():\n            return self.accelerator.unwrap_model(self.model).generate(\n                input_ids=input_ids, attention_mask=attention_mask, **kwargs\n            )\n\n    def save_pretrained(self, directory: Optional[str] = None, **kwargs):\n        \"\"\"Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for\n        later use.\n\n        Args:\n            directory (str, *optional*): The directory to save the trainer files to.\n                NOTE: If not specified, the model will be saved to a directory named `hf_model` in the\n                checkpoint directory as specified by the Trainer's config.\n            **kwargs: Additional keyword arguments passed to the underlying Hugging Face model's\n                `save_pretrained` method.\n        \"\"\"\n        if directory is None:\n            directory = os.path.join(self.config.train.checkpoint_dir, \"hf_model\")\n        self.accelerator.wait_for_everyone()\n        self.accelerator.unwrap_model(self.model).save_pretrained(directory, **kwargs)\n        if self.accelerator.is_main_process:\n            self.tokenizer.save_pretrained(directory)\n\n    def save(self, directory: Optional[str] = None, **kwargs):\n        \"\"\"Creates a checkpoint of the optimizer, scheduler and model\"\"\"\n        self.accelerator.save_state(directory or self.config.train.checkpoint_dir, **kwargs)\n\n    def load(self, directory: Optional[str] = None, **kwargs):\n        \"\"\"Load checkpoint of optimizer, scheduler and a model\"\"\"\n        self.accelerator.load_state(directory or self.config.train.checkpoint_dir, **kwargs)\n\n    def add_eval_pipeline(self, eval_pipeline):\n        \"\"\"Adds pipeline from with validation prompts\"\"\"\n        self.eval_pipeline = eval_pipeline\n\n    def evaluate(self):  # noqa: C901\n        \"\"\"Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided\"\"\"\n        logger.info(\"Evaluating model\")\n\n        # Do multiple evaluations over a single list in `gen_kwargs` if present\n        if 
self.generate_sweep_kwarg is not None:\n            gen_sweep_arg, gen_sweep_values = self.generate_sweep_kwarg\n        else:\n            gen_sweep_values = [None]\n\n        desc = [\n            f\"generation sweep 0/{len(gen_sweep_values)}\",\n            f\"eval batch 0/{len(self.eval_dataloader)}\",\n        ]\n        tbar = tqdm(\n            total=len(self.eval_dataloader) * len(gen_sweep_values),\n            desc=f\"[{' | '.join(desc)}]\",\n            disable=not self.accelerator.is_main_process,\n            position=0,\n            leave=True,\n        )\n\n        stats = {}\n        table = []\n\n        for i_sweep, gen_sweep_value in enumerate(gen_sweep_values):\n            # A dedicated suffix for wandb logging\n            if gen_sweep_value is not None:\n                sweep_suffix = f\"@{gen_sweep_arg}={gen_sweep_value}\"\n            else:\n                sweep_suffix = \"\"\n\n            all_samples = []\n            all_prompts = []\n            all_prompt_sizes = []\n            generate_time = time()\n            for i_prompt, prompts in enumerate(self.eval_dataloader):\n                logger.debug(f\"evaluate() - prompts keys: {prompts.keys()}, input_ids: {prompts['input_ids'].shape}\")\n                if self.generate_sweep_kwarg:\n                    samples = self.generate_eval(**prompts, **{gen_sweep_arg: gen_sweep_value})\n                else:\n                    samples = self.generate_eval(**prompts)\n\n                # if self.config.model.model_arch_type == \"seq2seq\":\n                #     samples = samples[:, 1:].contiguous()\n\n                prompt_sizes = torch.tensor(prompts['input_ids'].shape[1]).repeat(len(prompts['input_ids']))\n                prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(\n                    self.accelerator.pad_across_processes(\n                        [prompts['input_ids'], samples, prompt_sizes.to(samples.device)],\n                        dim=1,\n                        pad_index=self.tokenizer.pad_token_id,\n                    )\n                )\n                all_samples.extend(samples.tolist())\n                all_prompts.extend(prompts.tolist())\n                all_prompt_sizes.extend(prompt_sizes.tolist())\n\n                desc = [\n                    f\"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}\",\n                    f\"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}\",\n                ]\n                tbar.set_description(f\"[{' | '.join(desc)}]\")\n                tbar.update()\n            tbar.close()\n\n            stats[\"time/generate\"] = time() - generate_time\n\n            if self.accelerator.is_main_process:\n                str_samples, str_prompts, str_outputs, str_prefixes, _ = self.decode(all_prompts, all_samples, all_prompt_sizes)\n\n                columns = [\"prompt\", \"output\"]\n                columns_data = [str_prompts, str_outputs]\n\n                # in online setting, compute the reward for validation\n                if self.reward_fn:\n                    logger.info(\"Computing rewards\")\n                    rewards = torch.tensor(\n                        self.reward_fn(\n                            samples=str_samples,\n                            prompts=str_prompts,\n                            outputs=str_outputs,\n                        ),\n                        dtype=float,\n                    )\n                    mean_reward = rewards.mean().item()\n                    columns.append(\"reward\")\n   
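                 # convert to a plain list so per-sample rewards can be written into the logging table\n   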
                  if not isinstance(rewards, list):\n                        rewards = rewards.tolist()\n                    columns_data.append(rewards)\n                    stats[f\"reward/mean{sweep_suffix}\"] = mean_reward\n\n                # additionally log any other metrics\n                if self.metric_fn:\n                    logger.info(\"Computing metrics\")\n                    metric_time = time()\n                    metrics = self.metric_fn(\n                        samples=str_samples,\n                        prompts=str_prompts,\n                        outputs=str_outputs,\n                    )\n                    stats[\"time/metric\"] = time() - metric_time\n\n                    mean_metrics = {\n                        f\"metrics/{k}{sweep_suffix}\": torch.as_tensor(xs).mean(-1) for k, xs in metrics.items()\n                    }\n\n                    stats.update(mean_metrics)\n\n                    for metric, values in metrics.items():\n                        columns.append(metric)\n                        if not isinstance(values, list):\n                            values = values.tolist()\n                        columns_data.append(values)\n\n                # Prepend the sweep argument along with samples\n                if self.generate_sweep_kwarg:\n                    columns.insert(0, gen_sweep_arg)\n                    columns_data.insert(0, [gen_sweep_value] * len(samples))\n\n                table.append(list(zip(*columns_data)))\n\n        # Log and display evaluation metrics\n        logger.info(\"Summarizing evaluation\")\n        if self.accelerator.is_main_process:\n            rows = sum(list(map(list, zip(*table))), [])\n\n            # Add metrics/rewards to the table's title\n            table_title = f\"Evaluation #{self.nth_evaluation}\"\n            for k, x in stats.items():\n                if k.startswith(\"reward\") or k.startswith(\"metrics\"):\n                    table_title += f\" {k}: {significant(x)}\"\n\n            rich_table = Table(*columns, title=table_title, show_lines=True)\n            for ix in range(max(min(3, len(rows)), len(gen_sweep_values))):\n                rich_table.add_row(*[str(significant(x)) for x in rows[ix]])\n            Console().print(rich_table)\n\n            if not ray.is_initialized():\n                if self.config.train.tracker == \"wandb\":\n                    import wandb\n\n                    stats[\"samples\"] = wandb.Table(columns, rows)\n\n        self.nth_evaluation += 1\n        return stats\n\n    def learn(self):  # noqa: C901\n        \"\"\"\n        Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`\n        \"\"\"\n        logger.info(\"Starting training\")\n\n        self.generate_sweep_kwarg = None\n        for k, v in self.config.method.gen_kwargs.items():\n            if isinstance(v, list):\n                if self.generate_sweep_kwarg is not None:\n                    logger.info(f\"Only a single sweep is allowed, {k} is going to be set to {v[0]}\")\n                    self.generate_kwargs[k] = v[0]\n                else:\n                    self.generate_sweep_kwarg = (k, v)\n\n        self.prepare_learning()\n        self.iter_count = 0\n        self.nth_evaluation = 0\n\n        if ray.is_initialized():\n            checkpoint = session.get_checkpoint()\n            if checkpoint:\n                with checkpoint.as_directory() as dir:\n                    self.accelerator.load_state(dir)\n\n                    
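# restore the iteration counter persisted alongside the accelerate state\n                    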
with open(os.path.join(dir, \"state.json\")) as f:\n                        state = json.load(f)\n                        self.iter_count = state[\"iter_count\"]\n        else:\n            results = self.evaluate()\n            self.accelerator.log(results, step=self.iter_count)\n\n        tbar = tqdm(\n            initial=self.iter_count,\n            total=self.total_steps,\n            disable=not self.accelerator.is_local_main_process,\n            position=0,\n            leave=True,\n        )\n\n        best_reward = -float(\"inf\")\n\n        # For each epoch\n        for _ in range(self.config.train.epochs):\n            # For each batch\n            for batch in self.train_dataloader:\n                # For each update per batch\n                for _ in range(self.n_updates_per_batch):\n                    # Note that whereas standard policy gradient methods perform one\n                    # gradient update per batch, PPO for example commonly performs\n                    # multiple gradient updates on the same batch of data.\n                    # https://arxiv.org/pdf/1707.06347.pdf\n                    forward_time = time()\n                    loss, stats = self.loss(batch)\n                    forward_time = time() - forward_time\n                    backward_time = time()\n                    self.accelerator.backward(loss)\n                    backward_time = time() - backward_time\n\n                    self.opt.step()\n                    self.opt.zero_grad()\n                    self.scheduler.step()\n                    self.iter_count += 1\n\n                    if self.iter_count % self.config.train.checkpoint_interval == 0:\n                        subfolder = f\"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}\"\n                        directory = os.path.join(self.config.train.checkpoint_dir, subfolder)\n                        self.save(directory)\n\n                    stats[\"time/forward\"] = forward_time\n                    stats[\"time/backward\"] = backward_time\n                    for group_number, lr in enumerate(self.scheduler.get_last_lr()):\n                        stats[f\"learning_rate_group_{group_number}\"] = lr\n\n                    if self.iter_count % self.config.train.eval_interval == 0:\n                        results = self.evaluate()\n                        stats.update(results)\n\n                        # always save checkpoint with the greatest mean reward\n                        if self.config.train.save_best:\n                            if stats.get(\"reward/mean\", -float(\"inf\")) > best_reward:\n                                best_reward = stats.get(\"reward/mean\")\n                                do_save = True\n                            # in case ILQL reports reward estimate as one of its metrics\n                            elif stats.get(\"metrics/reward\", -float(\"inf\")) > best_reward:\n                                best_reward = stats.get(\"metrics/reward\")\n                                do_save = True\n                            else:\n                                do_save = False\n                            do_save = torch.tensor(do_save, device=self.accelerator.device)\n                            if torch.distributed.is_initialized():\n                                torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX)\n                            if do_save:\n                                best_path = f\"{self.config.train.checkpoint_dir}/best_checkpoint\"\n            
                    logger.info(f\"Saving the best state so far into {best_path}\")\n                                self.save(best_path)\n\n                        # Report the metrics to Ray Tune.\n                        if ray.is_initialized():\n                            self.save(\"state\")\n                            with open(\"state/state.json\", \"w\") as f:\n                                json.dump(dict(iter_count=self.iter_count), f)\n                            checkpoint = Checkpoint.from_directory(\"state\")\n                            session.report(filter_non_scalars(stats), checkpoint=checkpoint)\n\n                    if not ray.is_initialized():\n                        self.accelerator.log(stats, step=self.iter_count)\n\n                    desc = \" | \".join(f\"{k}: {v:.2f}\" for k, v in stats.items() if k.startswith(\"loss\"))\n                    tbar.set_description(f\"[{desc}]\")\n                    tbar.update()\n\n                    if self.iter_count >= self.total_steps:\n                        subfolder = f\"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}\"\n                        directory = os.path.join(self.config.train.checkpoint_dir, subfolder)\n                        self.save(directory)\n                        return self.evaluate()\n\n                self.post_backward_callback()\n\n            self.post_epoch_callback()\n        tbar.close()\n\n    @abstractmethod\n    def get_arch(self, config: TRLConfig):\n        \"\"\"Returns a specific wrapper of the decoder architecture\"\"\"\n        pass\n\n    @abstractmethod\n    def loss(self, batch) -> Tuple[float, Dict]:\n        \"\"\"Compute loss on a batch from `store` and return some statistics\"\"\"\n        pass\n\n    @abstractmethod\n    def post_backward_callback(self):\n        \"\"\"Do something after model update\"\"\"\n        pass\n\n    @abstractmethod\n    def post_epoch_callback(self):\n        \"\"\"Do something after exhausting/single pass over `self.store`\"\"\"\n        pass\n\n\n@register_trainer\nclass AcceleratePPOTrainer(AccelerateRLTrainer):\n    \"\"\"PPO Accelerate Trainer\"\"\"\n\n    reward_fn: Callable[[List[str], List[str], List[str]], List[float]]\n    tokenizer: AutoTokenizer\n\n    def __init__(self, config: TRLConfig, **kwargs):\n        \"\"\"PPO Accelerate Trainer initialization\n\n        Args:\n            config: Config\n        \"\"\"\n        super().__init__(config, **kwargs)\n\n        # Setup rollout logging\n        if config.train.rollout_logging_dir is not None:\n            self.log_rollouts = True\n            self.setup_rollout_logging(config)\n        else:\n            self.log_rollouts = False\n\n        # Setup the rollout store\n        # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout\n        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)\n\n        # Create the rollout store dataloader (for batching up rollouts)\n        # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future\n        rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True)\n\n        # Prepare multi-GPU acceleration\n        self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(\n            self.model, self.opt, self.scheduler, rollout_loader\n        )\n\n        self.store.clear_history()  # Clear the rollout store\n\n        # Setup a reference model when hydra heads are not 
used\n        if not hasattr(self.model, \"frozen_head\"):\n            self.ref_model = self.get_arch(self.config)\n            self.ref_model.to(self.accelerator.device)\n            self.ref_model.eval()\n\n        # Setup the KL controller\n        # This helps prevent large divergences in the controller (policy)\n        if config.method.target is not None:\n            self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon)\n        else:\n            self.kl_ctl = FixedKLController(config.method.init_kl_coef)\n\n        # Create the parameters for the Hugging Face language model's generator\n        # method (that generates new tokens from a prompt).\n        # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate\n        if config.model.model_arch_type == \"seq2seq\":\n            self.generate_kwargs = dict(\n                config.method.gen_kwargs,\n                eos_token_id=self.tokenizer.eop_token_id,\n                pad_token_id=self.tokenizer.pad_token_id,\n            )\n            if config.method.gen_experience_kwargs is not None:\n                self.generate_experience_kwargs = dict(\n                    config.method.gen_experience_kwargs,\n                    eos_token_id=self.tokenizer.eop_token_id,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                )\n            else:\n                self.generate_experience_kwargs = None\n        else:\n            self.generate_kwargs = dict(\n                config.method.gen_kwargs,\n                eos_token_id=self.tokenizer.eos_token_id,\n                pad_token_id=self.tokenizer.pad_token_id,\n            )\n            if config.method.gen_experience_kwargs is not None:\n                self.generate_experience_kwargs = dict(\n                    config.method.gen_experience_kwargs,\n                    eos_token_id=self.tokenizer.eos_token_id,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                )\n            else:\n                self.generate_experience_kwargs = None\n\n        # Setup stats tracker\n        self.running_moments = RunningMoments()\n        self.ref_mean = self.config.method.ref_mean\n        self.ref_std = self.config.method.ref_std\n\n    def get_arch(self, config: TRLConfig):\n        \"\"\"Get the model\"\"\"\n        model_class = AutoModelForCausalLMWithHydraValueHead\n        if config.model.model_arch_type == \"seq2seq\":\n            model_class = AutoModelForSeq2SeqLMWithHydraValueHead\n\n        from_fn = model_class.from_pretrained\n        # backward-compat: Try to create a randomly initialized architecture from a config\n        if issubclass(type(config.model.model_path), PretrainedConfig):\n            from_fn = model_class.from_config\n\n        model = from_fn(\n            config.model.model_path,\n            trust_remote_code=True,\n            num_layers_unfrozen=config.model.num_layers_unfrozen,\n            config=config\n        )\n\n        return model\n\n    def loss(self, batch: PPORLBatch):\n        \"\"\"Forward pass & loss\n\n        Args:\n            batch: Previous batch of episodes\n        \"\"\"\n        # Move `batch` data to `accelerator` device\n        input_ids = batch.query_tensors.to(self.accelerator.device)\n        response_tensors = batch.response_tensors.to(self.accelerator.device)\n        attention_mask = batch.attention_mask.to(self.accelerator.device)\n        
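# The rollout-time (\"old\") logprobs, values and rewards below are compared\n        # against fresh forward-pass estimates; `self.config.method.loss` is expected\n        # to apply PPO's clipped objective (https://arxiv.org/abs/1707.06347), i.e.\n        #   r_t = exp(logprobs - old_logprobs)\n        #   loss_policy = -min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)\n        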
old_logprobs = batch.logprobs.to(self.accelerator.device)\n        old_values = batch.values.to(self.accelerator.device)\n        old_rewards = batch.rewards.to(self.accelerator.device)\n        response_length = old_rewards.shape[1]\n        logger.debug(f\"loss() - input ids shape: {input_ids.shape}, attention mask shape: {attention_mask.shape}\")\n\n        advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)\n\n        if self.config.model.model_arch_type == \"seq2seq\":\n            # TODO: To be modified for glm and chatglm\n            # input_ids = query_tensors\n            decoder_input_ids = response_tensors\n            # attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)\n            position_ids = torch.stack(batch.position_ids).to(self.accelerator.device)\n            # decoder_attention_mask = (\n            #     decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)\n            # )\n            # decoder_attention_mask[:, 0] = 1\n            logger.debug(f\"loss() - position ids shape: {position_ids.shape}\")\n\n            # Forward pass\n            outputs = self.model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                position_ids=position_ids\n                # decoder_input_ids=decoder_input_ids,\n                # decoder_attention_mask=decoder_attention_mask,\n            )\n\n            logits = outputs.logits\n            values_pred = outputs.value\n            logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])\n            mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)\n            start = 0\n            end = start + response_length\n            logprobs, values_pred, mask = (\n                logprobs[:, start:end],\n                values_pred[:, start:end],\n                mask[:, start:end],\n            )\n        else:\n            # tokens = torch.cat((query_tensors, response_tensors), dim=1)\n            # attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)\n            outputs = self.model(input_ids, attention_mask, return_dict=True)\n            logits = outputs.logits\n            values_pred = outputs.value\n            logger.info(f\"loss() - s1 values_pred shape: {values_pred.shape}\")\n            values_pred = values_pred[:, :-1]\n            logger.info(f\"loss() - s2 values_pred shape: {values_pred.shape}\")\n            logprobs = logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])\n\n            start = input_ids.shape[1] - 1\n            end = start + response_length\n            logprobs, values_pred, mask = (\n                logprobs[:, start:end],\n                values_pred[:, start:end],\n                attention_mask[:, start:end],\n            )\n            logger.info(f\"loss() - s3 values_pred shape: {values_pred.shape}\")\n\n        # TODO: need debugging here\n        loss, stats = self.config.method.loss(\n            logprobs=logprobs,\n            values=values_pred,\n            old_logprobs=old_logprobs,\n            old_values=old_values,\n            advantages=advantages,\n            returns=returns,\n            mask=mask,\n        )\n\n        return loss, stats\n\n    def setup_rollout_logging(self, config):\n        # Make rollout logging dir for this run and store config\n        exists = 
os.path.exists(config.train.rollout_logging_dir)\n        isdir = os.path.isdir(config.train.rollout_logging_dir)\n        assert exists and isdir\n\n        self.run_id = f\"run-{uuid.uuid4()}\"\n        self.rollout_logging_dir = os.path.join(config.train.rollout_logging_dir, self.run_id)\n        os.mkdir(self.rollout_logging_dir)\n\n        with open(os.path.join(self.rollout_logging_dir, \"config.json\"), \"w\") as f:\n            f.write(json.dumps(config.to_dict(), indent=2))\n\n    def post_epoch_callback(self):\n        \"\"\"Post epoch callback\n\n        Clears the store and creates `num_rollouts` new episodes.\n        \"\"\"\n        if self.log_rollouts:\n            self.store.export_history(location=self.rollout_logging_dir)\n        self.store.clear_history()\n        # Collect more rollouts for training\n        self.make_experience(self.config.method.num_rollouts, self.iter_count)\n\n    def post_backward_callback(self):\n        self.kl_ctl.update(self.mean_kl.item(), n_steps=self.config.train.batch_size)\n\n    def prepare_learning(self):\n        eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)\n        self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)\n        self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True)\n\n        self.n_updates_per_batch = self.config.method.ppo_epochs\n        self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader)\n        self.total_steps = min(self.total_steps, self.config.train.total_steps)\n\n    def add_prompt_pipeline(self, pipeline: BasePipeline):\n        \"\"\"Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage\"\"\"\n        prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True)\n        self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)\n        self.prompt_iterator = iter(self.prompt_dataloader)\n\n    def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:\n        \"\"\"Make experiences\n\n        Takes `chunk_size` number of prompts from `prompt_iterator`, samples\n        from the model and then computes the KL against a reference model. Finally, it\n        appends PPORLElements to the trainer's `store`.\n\n        Args:\n            num_rollouts: Number of rollouts to generate\n            iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs)\n        \"\"\"\n        logger.info(\"Collecting rollouts\")\n        tbar = tqdm(\n            total=num_rollouts,\n            disable=os.environ.get(\"RANK\", \"0\") != \"0\",\n            desc=f\"[rollout 0 / {num_rollouts}]\",\n            # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress\n            # bars (e.g. 
loss progress in trainers)\n            position=logger.level >= logging.WARNING,\n            # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels\n            leave=logger.level < logging.WARNING,\n        )\n\n        ppo_rl_elements = []\n        stats = {}\n        clock = Clock()\n\n        while len(ppo_rl_elements) < num_rollouts:\n            # Get next batch in prompt dataset and refresh if exhausted\n            # TODO (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually\n            # \"refreshing\" the contents of the `prompt_iterator`\n            try:\n                batch: PromptBatch = next(self.prompt_iterator)\n            except StopIteration:\n                self.prompt_iterator = iter(self.prompt_dataloader)\n                batch = next(self.prompt_iterator)\n\n            exp_generate_time = time()\n\n            # Generate samples from the language model (similar to using HuggingFace `generate` method)\n            logger.debug(f\"generate() input `batch` keys: {batch.keys()}\")\n            samples = self.generate(**batch)\n            for i in range(len(batch['input_ids'])):\n                p = self.tokenizer.decode(batch['input_ids'][i], skip_special_tokens=True)\n                gt = self.tokenizer.decode(samples[i], skip_special_tokens=True)\n                logger.debug(f\"prompt: {p}, generated result: {gt}, samples: {samples[i]}\")\n            logger.debug(f\"make_experience() - input ids shape: {batch['input_ids'].shape}, samples shape: {samples.shape}\")\n            stats[\"time/exp_generate\"] = time() - exp_generate_time\n\n            prompt_tensors = batch['input_ids']\n            device = samples.device\n\n            prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)\n            padded_samples = self.accelerator.pad_across_processes(\n                samples, dim=1, pad_index=self.tokenizer.pad_token_id, pad_first=False\n            )\n            padded_prompts = self.accelerator.pad_across_processes(\n                prompt_tensors, dim=1, pad_index=self.tokenizer.pad_token_id, pad_first=False\n            )\n            gathered_samples = self.accelerator.gather(padded_samples)\n            gathered_prompts = self.accelerator.gather(padded_prompts)\n            gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)\n\n            if self.accelerator.is_main_process:\n                all_str_samples, all_str_prompts, all_str_outputs, all_str_prefixes, _ = self.decode(\n                    gathered_prompts, gathered_samples, gathered_prompt_sizes\n                )\n\n                exp_score_time = time()\n                all_scores = torch.tensor(\n                    self.reward_fn(\n                        samples=all_str_samples,\n                        prompts=all_str_prompts,\n                        outputs=all_str_outputs,\n                    ),\n                    dtype=torch.float,\n                    device=device,\n                )\n                stats[\"time/exp_score\"] = time() - exp_score_time\n\n                all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind())\n            else:\n                all_scores = None\n\n            if torch.distributed.is_initialized():\n                scores = torch.empty(len(samples), device=device)\n                torch.distributed.scatter(scores, all_scores)\n            else:\n                scores = 
torch.tensor(all_scores[0])\n\n            str_samples, str_prompts, str_outputs, str_prefixes, outputs = self.decode(prompt_tensors, samples)\n\n            # Pad the sample outputs\n            # outputs = self.tokenizer(str_outputs).input_ids\n            # if self.config.model.model_arch_type == \"seq2seq\":\n            #     # add <pad> to the start of the output\n            #     for i in range(len(outputs)):\n            #         outputs[i] = [self.tokenizer.pad_token_id] + outputs[i]\n            # outputs = list(map(torch.LongTensor, outputs))\n            maxsize = max(map(len, outputs))\n            outputs = [\n                F.pad(\n                    output,\n                    (0, maxsize - len(output)),\n                    value=self.tokenizer.pad_token_id,\n                )\n                for output in outputs\n            ]\n            sample_outputs = torch.vstack(outputs).to(device)\n\n            # store statistics of the initial rollout as reference\n            if self.ref_mean is None:\n                self.ref_mean, self.ref_std = scores.mean(), scores.std()\n            all_scores_mean, all_scores_std = self.running_moments.update(scores)\n            stats[\"exp_scores/mean\"] = all_scores_mean\n            stats[\"exp_scores/std\"] = all_scores_std\n            stats[\"exp_scores/running_mean\"] = self.running_moments.mean\n            stats[\"exp_scores/running_std\"] = self.running_moments.std\n\n            if self.config.method.scale_reward == \"running\":\n                scores /= self.running_moments.std\n            elif self.config.method.scale_reward == \"ref\":\n                scores /= self.ref_std\n\n            clip_reward = self.config.method.cliprange_reward\n            if clip_reward:\n                scores = torch.clip(scores, -clip_reward, clip_reward)\n\n            # Precompute logprobs, values\n            logger.debug(f\"sample_outputs shape: {sample_outputs.shape}\")\n            logger.debug(f\"str_prompts[0]: {str_prompts[0]}, str_outputs[0]: {str_outputs[0]}, input_ids[0]: {batch['input_ids'][0]}, sample_outputs[0]: {sample_outputs[0]}\")\n            # logger.debug(f\"str_prompts[1]: {str_prompts[1]}, str_outputs[1]: {str_outputs[1]}, input_ids[1]: {batch['input_ids'][1]}, sample_outputs[1]: {sample_outputs[1]}\")\n            self.tokenizer.padding_side = \"right\"\n            if self.config.model.model_arch_type == \"seq2seq\":\n                input_ids, attention_mask, position_ids = [], [], []\n                for str_prompt, str_output, str_prefix in zip(str_prompts, str_outputs, str_prefixes):\n                    encoded_prompt = self.tokenizer(str_prompt, str_prefix + self.tokenizer.mask_token)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    label_length = len(self.tokenizer.tokenize(str_output)) + 1\n                    if prompt_length + label_length > self.max_length:\n                        num_tokens_to_remove = prompt_length + label_length - self.max_length\n                        for _ in range(num_tokens_to_remove):\n                            if prompt_length > label_length:\n                                prompt_length -= 1\n                            else:\n                                label_length -= 1\n                    else:\n                        label_length = self.max_length - prompt_length\n                    assert prompt_length > 0\n                    assert label_length > 0\n                    assert prompt_length + label_length <= 
self.max_length\n                    encoded_dict = self.tokenizer(str_prompt, str_prefix + self.tokenizer.mask_token,\n                                                  max_length=prompt_length,\n                                                  truncation=\"only_first\",\n                                                  return_tensors=\"pt\",\n                                                  return_attention_mask=True,\n                                                  return_token_type_ids=False)\n                    encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict, targets=str_output,\n                                                                              max_gen_length=label_length, padding=True)\n                    input_ids.append(encoded_dict['input_ids'])\n                    attention_mask.append(encoded_dict['attention_mask'])\n                    position_ids.append(encoded_dict['position_ids'])\n                input_ids = torch.cat(input_ids).to(device)\n                attention_mask = torch.cat(attention_mask).to(device)\n                position_ids = torch.cat(position_ids).to(device)\n                with torch.no_grad():\n                    outputs = self.model(\n                        input_ids=input_ids,\n                        attention_mask=attention_mask,\n                        position_ids=position_ids\n                    )\n                    logits = outputs.logits\n                    values = outputs.value\n                    if hasattr(self.model, \"frozen_head\"):\n                        ref_logits = self.model.forward_hydra(\n                            input_ids=input_ids,\n                            attention_mask=attention_mask,\n                            position_ids=position_ids,\n                            # decoder_input_ids=sample_outputs,\n                            # decoder_attention_mask=decoder_attention_mask,\n                            return_dict=True,\n                        ).logits\n                    else:\n                        ref_logits = self.ref_model(\n                            input_ids=input_ids,\n                            attention_mask=attention_mask,\n                            position_ids=position_ids,\n                            # decoder_input_ids=sample_outputs,\n                            # decoder_attention_mask=decoder_attention_mask,\n                            return_dict=True,\n                        ).logits\n\n            else:\n                # all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)\n                # attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)\n                encoded_dict = self.tokenizer(str_prompts, str_outputs, max_length=self.max_length, return_tensors=\"pt\",\n                                              truncation=\"longest_first\", padding=\"max_length\", return_token_type_ids=False)\n                input_ids = encoded_dict['input_ids'].to(device)\n                attention_mask = encoded_dict['attention_mask'].to(device)\n                position_ids = None\n                with torch.no_grad():\n                    logits, *_, values = self.model(\n                        input_ids,\n                        attention_mask=attention_mask,\n                    )\n                    # TODO(dahoas): When hydra model works need to also support generation on hydra head\n                    if hasattr(self.model, \"frozen_head\"):\n                        ref_logits = 
self.model.forward_hydra(\n                            input_ids,\n                            attention_mask=attention_mask,\n                            return_dict=True,\n                        ).logits\n                    else:\n                        ref_logits = self.ref_model(\n                            input_ids,\n                            attention_mask=attention_mask,\n                            return_dict=True,\n                        ).logits\n                        ref_logits = ref_logits.to(device)\n            self.tokenizer.padding_side = self.padding_side\n\n            if self.config.model.model_arch_type == \"seq2seq\":\n                # TODO: to be tested against glm and chatglm\n                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])\n                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])\n            else:\n                logprobs = logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])\n                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], input_ids[:, 1:])\n\n            n_samples: int = samples.shape[0]\n            logprobs = logprobs.cpu()\n            ref_logprobs = ref_logprobs.cpu()\n            # prompt_tensors = prompt_tensors.cpu()\n            sample_outputs = sample_outputs.cpu()\n            input_ids = input_ids.cpu()\n            attention_mask = attention_mask.cpu()\n            position_ids = position_ids.cpu() if position_ids is not None else None\n            values = values.cpu()[:, :-1]\n\n            # Estimate the KL divergence between the model and reference model\n            if self.config.model.model_arch_type == \"seq2seq\":\n                # TODO: to be modified for glm and chatglm\n                attention_mask_tmp = sample_outputs != self.tokenizer.pad_token_id\n                start = 0\n            else:\n                attention_mask_tmp = attention_mask\n                start = prompt_tensors.shape[1] - 1\n\n            ends = start + attention_mask_tmp[:, start:].sum(1)\n\n            # Get the logprobs and values, for tokens that are not padding\n            # or beginning of sequences tokens. 
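(for decoder-only models this is the generated answer span, starting right after the prompt). 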
These are from the model (not the reference model)\n            all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]\n            all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]\n\n            log_ratio = (logprobs - ref_logprobs) * attention_mask_tmp[:, :-1].cpu()\n            self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)\n            kl_penalty = self.kl_ctl.value * -log_ratio\n            kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]\n\n            rollout_count = 0\n\n            for sample_idx in range(n_samples):\n                if len(kl_penalty[sample_idx]) == 0 or len(all_logprobs[sample_idx]) == 0:\n                    continue\n\n                rewards = kl_penalty[sample_idx]\n                rewards[-1] += scores[sample_idx].cpu()\n\n                logger.debug(f\"make_experience() - attention mask shape: {attention_mask[sample_idx].shape}\")\n\n                ppo_rl_elements.append(\n                    PPORLElement(\n                        query_tensor=input_ids[sample_idx],\n                        # query_tensor=prompt_tensors[sample_idx],\n                        response_tensor=sample_outputs[sample_idx],\n                        attention_mask=attention_mask[sample_idx],\n                        position_ids=position_ids[sample_idx] if position_ids is not None else None,\n                        logprobs=all_logprobs[sample_idx],\n                        values=all_values[sample_idx],\n                        rewards=rewards,\n                    )\n                )\n\n                rollout_count += 1\n            exp_time = clock.tick()\n            tbar.set_description(f\"[rollout {len(ppo_rl_elements)} / {num_rollouts}]\")\n            tbar.update(min(rollout_count, num_rollouts))\n        tbar.close()\n\n        if torch.distributed.is_initialized():\n            torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)\n\n        stats[\"policy/sqrt_kl\"] = torch.sqrt(self.mean_kl)\n        stats[\"kl_ctl_value\"] = self.kl_ctl.value\n        stats[\"time/exp\"] = exp_time\n\n        if not ray.is_initialized():\n            self.accelerator.log(stats, step=iter_count)\n\n        # Push samples and rewards to trainer's rollout storage\n        self.push_to_store(ppo_rl_elements)\n\n\ndef get_model_norm(model):\n    with torch.no_grad():\n        total = 0.0\n        for param in model.parameters():\n            should_gather = hasattr(\n                param,\n                'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE\n            with deepspeed.zero.GatheredParameters(param,\n                                                   enabled=should_gather):\n                total += float(param.float().norm())\n\n    return total\n\n\ndef gather_log_probs(logits, labels):\n    log_probs = F.log_softmax(logits, dim=-1)\n    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))\n    return log_probs_labels.squeeze(-1)\n\n\nclass DeepSpeedPPOTrainer():\n\n    def __init__(self, rlhf_engine, args):\n        self.rlhf_engine = rlhf_engine\n        self.actor_model = self.rlhf_engine.actor\n        self.critic_model = self.rlhf_engine.critic\n        self.ref_model = self.rlhf_engine.ref\n        self.reward_model = self.rlhf_engine.reward\n        self.tokenizer = self.rlhf_engine.tokenizer\n        self.args = args\n        self.max_answer_seq_len = args.max_gen_length\n        # self.end_of_conversation_token_id = 
self.tokenizer(\n        #     args.end_of_conversation_token)['input_ids'][-1]\n        self.end_of_conversation_token_id = self.tokenizer.eos_token_id\n\n        # These values can be changed\n        self.kl_ctl = args.kl_coefficient\n        self.clip_reward_value = args.clip_reward_value\n        self.cliprange = args.clip_range\n        self.cliprange_value = args.clip_range_value\n        self.gamma = args.gamma\n        self.lam = args.lambda_\n\n    def generate_sequence(self, inputs):\n        self.eval()\n        print_gpu_utilization(\"generate_sequence - before model.generate\", self.args.local_rank)\n        print_gpu_utilization_torch(\"generate_sequence - before model.generate\", self.args.local_rank)\n        batch_size = inputs['input_ids'].shape[0]\n        prompt_length = inputs['input_ids'].shape[-1]\n\n        with torch.no_grad():\n            logger.debug(f\"[generate_sequence] inputs: {inputs}\")\n            prompts = []\n            answers = []\n            outputs = dict()\n            for i in range(batch_size):\n                input = {k: v[i].unsqueeze(0) for k, v in inputs.items()}\n                prompt = self.tokenizer.decode(input['input_ids'][0], skip_special_tokens=False)\n                if \"pangu\" in self.args.actor_model_path:\n                    seq = self.actor_model.module.generate(**input,\n                                                           max_new_tokens=self.max_answer_seq_len,\n                                                           pad_token_id=self.tokenizer.pad_token_id,\n                                                           do_sample=self.args.do_sample,\n                                                           num_return_sequences=self.args.num_return_sequences,\n                                                           top_p=self.args.top_p,\n                                                           temperature=self.args.temperature)\n                    for output_ids in seq:\n                        answer = self.tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)\n                        # Since the prompt contains <sep>, we cannot use tokenizer(prompts, answers). 
Therefore concat prompt and answer, use tokenizer(prompt+answer) instead\n                        prompts.append(prompt + answer)\n                elif \"chatglm\" in self.args.actor_model_path:\n                    seq = self.actor_model.module.generate(**input,\n                                                           max_new_tokens=self.max_answer_seq_len,\n                                                           eos_token_id=self.tokenizer.eop_token_id,\n                                                           pad_token_id=self.tokenizer.pad_token_id,\n                                                           do_sample=self.args.do_sample,\n                                                           num_return_sequences=self.args.num_return_sequences,\n                                                           top_p=self.args.top_p,\n                                                           temperature=self.args.temperature)\n                    logger.debug(f\"[generate_sequence] seq: {seq}\")\n                    for output_ids in seq:\n                        answer = self.tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)\n                        prompts.append(prompt)\n                        answers.append(answer)\n                elif \"glm\" in self.args.actor_model_path:\n                    seq = self.actor_model.module.generate(**input,\n                                                           max_new_tokens=self.max_answer_seq_len,\n                                                           eos_token_id=self.tokenizer.eop_token_id,\n                                                           pad_token_id=self.tokenizer.pad_token_id,\n                                                           do_sample=self.args.do_sample,\n                                                           num_return_sequences=self.args.num_return_sequences,\n                                                           top_p=self.args.top_p,\n                                                           temperature=self.args.temperature)\n                    for output_ids in seq:\n                        answer = self.tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)\n                        label_length = len(self.tokenizer.tokenize(answer)) + 1\n                        if prompt_length + label_length > self.args.max_length:\n                            num_tokens_to_remove = prompt_length + label_length - self.args.max_length\n                            for _ in range(num_tokens_to_remove):\n                                if prompt_length > label_length:\n                                    prompt_length -= 1\n                                else:\n                                    label_length -= 1\n                        else:\n                            label_length = self.args.max_length - prompt_length\n                        assert prompt_length > 0\n                        assert label_length > 0\n                        assert prompt_length + label_length == self.args.max_length\n                        encoded_dict = self.tokenizer(prompt,\n                                                      max_length=prompt_length,\n                                                      return_tensors=\"pt\",\n                                                      return_attention_mask=True,\n                                                      return_token_type_ids=False,\n                                                      
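# the prompt was decoded above with skip_special_tokens=False, so special tokens are not re-added here\n                                                      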
add_special_tokens=False)\n                        encoded_dict = self.tokenizer.build_inputs_for_generation(encoded_dict,\n                                                                                  targets=answer,\n                                                                                  max_gen_length=label_length,\n                                                                                  padding=True)\n                        for key, val in encoded_dict.items():\n                            if key not in outputs:\n                                outputs[key] = []\n                            outputs[key].append(val[0])\n                else:\n                    raise ValueError(f\"Unsupported model name: {self.args.actor_model_path}\")\n\n            if \"pangu\" in self.args.actor_model_path:\n                outputs = self.tokenizer(prompts, max_length=self.args.max_length,\n                                         padding=\"max_length\", return_tensors=\"pt\", return_token_type_ids=False)\n                logger.debug(f\"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}\")\n            elif \"chatglm\" in self.args.actor_model_path:\n                outputs = self.tokenizer(prompts, answers, max_length=self.args.max_length,\n                                         padding=\"max_length\", return_tensors=\"pt\")\n                logger.debug(f\"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}\")\n            elif \"glm\" in self.args.actor_model_path:\n                outputs = {key: torch.stack(val) for key, val in outputs.items()}\n                logger.debug(f\"[generate_sequence] outputs['input_ids'].shape: {outputs['input_ids'].shape}, outputs: {outputs}\")\n            else:\n                raise ValueError(f\"Unsupported model name: {self.args.actor_model_path}\")\n        print_gpu_utilization(\"generate_sequence - after model.generate\", self.args.local_rank)\n        print_gpu_utilization_torch(\"generate_sequence - after model.generate\", self.args.local_rank)\n        # Filter out seq with no answers (or very short). 
This happens when users directly use the pre-training ckpt without supervised finetuning\n        # NOTE: this will cause each GPU to have a different number of examples\n        # ans = seq[:, prompt_length:]\n        # self.prompt_length = prompt_length\n        # valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)\n        # out_seq = []\n        # for i in range(batch_size):\n        #     # if the answer is shorter than 1 token, drop it\n        #     if valid_ans_len[i] <= 1:\n        #         continue\n        #     else:\n        #         out_seq.append(seq[i:i + 1])\n        # out_seq = torch.cat(out_seq, dim=0)  # concat output in the batch dim\n        # logger.debug(f\"[generate_sequence] out_seq: {out_seq}\")\n\n        return outputs, prompt_length\n\n    def generate_experience(self, output_sequences, answer_start_indices):\n        self.eval()\n        print_gpu_utilization(\"generate_experience - before call actor and critic\", self.args.local_rank)\n        print_gpu_utilization_torch(\"generate_experience - before call actor and critic\", self.args.local_rank)\n\n        # pad_token_id = self.tokenizer.pad_token_id\n        input_ids = output_sequences['input_ids']\n        attention_mask = output_sequences['attention_mask'] if \"attention_mask\" in output_sequences else None\n        position_ids = output_sequences['position_ids'] if \"position_ids\" in output_sequences else None\n        print_gpu_utilization(\"generate_experience - after setting output_sequences device\", self.args.local_rank)\n        print_gpu_utilization_torch(\"generate_experience - after setting output_sequences device\", self.args.local_rank)\n\n        with torch.no_grad():\n            output = self.actor_model(input_ids, attention_mask=attention_mask, position_ids=position_ids)\n            output_ref = self.ref_model(input_ids, attention_mask=attention_mask, position_ids=position_ids)\n            output_reward = self.reward_model(input_ids, attention_mask, position_ids)\n            reward_score = output_reward['chosen_reward'].detach()\n            if self.critic_model is not None:\n                values = self.critic_model(input_ids, attention_mask, position_ids)['chosen_values'].detach()\n            else:\n                values = output_reward['chosen_values'].detach()\n        print_gpu_utilization(\"generate_experience - after call actor and critic\", self.args.local_rank)\n        print_gpu_utilization_torch(\"generate_experience - after call actor and critic\", self.args.local_rank)\n\n        logits = output.logits\n        logits_ref = output_ref.logits\n\n        return {\n            # 'prompts': inputs['input_ids'],\n            'answer_start_indices': answer_start_indices,\n            'input_ids': input_ids,\n            'attention_mask': attention_mask,\n            'position_ids': position_ids,\n            'logprobs': gather_log_probs(logits[:, :-1, :], input_ids[:, 1:]),\n            'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], input_ids[:, 1:]),\n            'value': values[:, :-1],\n            'rewards': reward_score\n        }\n\n    def compute_rewards(self, starts, log_probs, ref_log_probs, reward_score, action_mask):\n        '''\n\n        :param starts: list of start indices of the answer in each sequence\n        :param log_probs: shape=batch_size * (max_length-1)\n        :param ref_log_probs: shape=batch_size * (max_length-1)\n        :param reward_score: shape=batch_size\n        :param action_mask: shape=batch_size * (answer_length)\n       
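 :note: the per-token reward below is -kl_ctl * (log_probs - ref_log_probs); the clipped reward_score is then added at the last answer token\n       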
 :return:\n        '''\n        logger.debug(f\"[compute_rewards] log_probs: {log_probs.shape}, ref_log_probs: {ref_log_probs.shape}, \"\n                    f\"reward_score: {reward_score.shape}, action_mask: {action_mask.shape}\")\n        kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)\n        rewards = kl_divergence_estimate\n        logger.debug(f\"before rewards: {rewards.shape}\")\n        # start = prompts.shape[1] - 1\n        # ends = start + action_mask.sum(1)\n        sums = action_mask.sum(1)\n        reward_clip = torch.clamp(reward_score, -self.clip_reward_value,\n                                  self.clip_reward_value)\n        batch_size = log_probs.shape[0]\n        for j in range(batch_size):\n            logger.debug(f\"j={j}, sums[j]={sums[j]}, rewards[j, start:ends[j]]: {rewards[j, starts[j]:(starts[j]+sums[j])].shape}\")\n            rewards[j, starts[j]:(starts[j]+sums[j])][-1] += reward_clip[j]\n        logger.debug(f\"after rewards: {rewards.shape}\")\n        return rewards\n\n    def train_rlhf(self, inputs):\n        # process the old outputs\n        answer_start_indices = inputs['answer_start_indices']\n        log_probs = inputs['logprobs'] # shape=batch_size * (max_length-1)\n        ref_log_probs = inputs['ref_logprobs'] # shape=batch_size * (max_length-1)\n        reward_score = inputs['rewards'] # shape=batch_size\n        values = inputs['value'] # shape=batch_size * (max_length-1)\n        attention_mask = inputs['attention_mask'] # shape=batch_size * max_length or shape=batch_size * max_length * max_length\n        position_ids = inputs['position_ids'] # shape=batch_size * 2 * max_length\n        input_ids = inputs['input_ids'] # shape=batch_size * max_length\n        logger.debug(f\"[train_rlhf] answer_start_indices: {answer_start_indices}, \"\n                     f\"log_probs shape: {log_probs.shape}, ref_log_probs shape: {ref_log_probs.shape}, \"\n                     f\"reward_score shape: {reward_score.shape}, values shape: {values.shape}, \"\n                     f\"attention_mask shape: {attention_mask.shape if attention_mask is not None else None},\"\n                     f\"position_ids shape: {position_ids.shape if position_ids is not None else None},\"\n                     f\"input_ids shape: {input_ids.shape}\")\n\n        batch_size = input_ids.size()[0]\n        if attention_mask is not None and len(attention_mask.shape) == 2:\n            # action_mask = attention_mask[:, 1:][:, start:]\n            action_mask = attention_mask[:, 1:]\n        else:\n            # answer_ids = input_ids[:, 1:][:, start:]\n            # batch_size = answer_ids.shape[0]\n            # answer_length = answer_ids.shape[-1]\n            answer_length = input_ids.shape[-1] - 1\n            action_mask = torch.ones((batch_size, answer_length), dtype=torch.long, device=input_ids.device)\n            for i, j in (input_ids[:, 1:] == self.tokenizer.pad_token_id).nonzero():\n                action_mask[i, j] = 0\n        for i in range(batch_size):\n            # set mask of prompt to 0\n            action_mask[i, :answer_start_indices[i]] = 0\n        logger.debug(f\"[train_rlhf] action_mask shape: {action_mask.shape}\")\n\n        # compute advantages and returns\n        print_gpu_utilization(\"train_rlhf - before compute reward and advantages\", self.args.local_rank)\n        print_gpu_utilization_torch(\"train_rlhf - before compute reward and advantages\", self.args.local_rank)\n        old_values = values\n        with 
torch.no_grad():\n            old_rewards = self.compute_rewards(answer_start_indices, log_probs,\n                                               ref_log_probs, reward_score,\n                                               action_mask)\n            advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, answer_start_indices)\n            logger.debug(f\"[train_rlhf] old_rewards shape: {old_rewards.shape}, advantages shape: {advantages.shape}, returns shape: {returns.shape}\")\n        print_gpu_utilization(\"train_rlhf - after compute reward and advantages\", self.args.local_rank)\n        print_gpu_utilization_torch(\"train_rlhf - after compute reward and advantages\", self.args.local_rank)\n\n        # update actor and critic\n        self.train()\n        batch = {'input_ids': input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n        actor_prob = self.actor_model(**batch, use_cache=False).logits # shape=batch_size * max_length * vocab_size\n        print_gpu_utilization(\"train_rlhf - after self.actor_model\", self.args.local_rank)\n        print_gpu_utilization_torch(\"train_rlhf - after self.actor_model\", self.args.local_rank)\n        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],  input_ids[:, 1:])\n        actor_loss = self.actor_loss_fn(actor_log_prob,\n                                        log_probs, advantages,\n                                        action_mask)\n        self.actor_model.backward(actor_loss)\n        print_gpu_utilization(\"train_rlhf - after actor backward\", self.args.local_rank)\n        print_gpu_utilization_torch(\"train_rlhf - after actor backward\", self.args.local_rank)\n        self.actor_model.step()\n        print_gpu_utilization(\"train_rlhf - after actor step\", self.args.local_rank)\n        print_gpu_utilization_torch(\"train_rlhf - after actor step\", self.args.local_rank)\n\n        if self.critic_model is not None:\n            value = self.critic_model.reward(**batch, use_cache=False)[0][:, :-1] # shape=batch_size * (max_length-1)\n            print_gpu_utilization(\"train_rlhf - after self.critic_model\", self.args.local_rank)\n            print_gpu_utilization_torch(\"train_rlhf - after self.critic_model\", self.args.local_rank)\n            critic_loss = self.critic_loss_fn(value, old_values,\n                                              returns, action_mask)\n            self.critic_model.backward(critic_loss)\n            print_gpu_utilization(\"train_rlhf - after critic backward\", self.args.local_rank)\n            print_gpu_utilization_torch(\"train_rlhf - after critic backward\", self.args.local_rank)\n            self.critic_model.step()\n            print_gpu_utilization(\"train_rlhf - after critic step\", self.args.local_rank)\n            print_gpu_utilization_torch(\"train_rlhf - after critic step\", self.args.local_rank)\n        else:\n            critic_loss = None\n\n        return actor_loss, critic_loss\n\n    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):\n        ## Clipped Surrogate Objective for policy update in PPO (https://arxiv.org/abs/1707.06347)\n        log_ratio = (logprobs - old_logprobs) * mask\n        ratio = torch.exp(log_ratio)\n        pg_objective1 = advantages * ratio\n        pg_objective2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange,\n                                                 1.0 + self.cliprange)\n        pg_objective = torch.sum(torch.min(pg_objective1, pg_objective2) * mask) / mask.sum()\n       
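 # e.g. cliprange=0.2 clamps the ratio to [0.8, 1.2]; min() keeps the pessimistic surrogate,\n        # so the update gets no extra credit for moving the policy beyond the clip window\n       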
 return -pg_objective\n\n    def critic_loss_fn(self, values, old_values, returns, mask):\n        # TODO: Clipped surrogate objective for value function (? not seen in original paper)\n        values_clipped = torch.clamp(\n            values,\n            old_values - self.cliprange_value,\n            old_values + self.cliprange_value,\n            )\n        # Squared-error loss of value function (https://arxiv.org/abs/1707.06347)\n        vf_loss1 = (values - returns)**2\n        vf_loss2 = (values_clipped - returns)**2\n        # TODO: using max puts a lower bound and no upper bound on the loss, is this really desired?\n        vf_loss = 0.5 * torch.sum(\n            torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()\n        return vf_loss\n\n    def get_advantages_and_returns(self, values, rewards, starts):\n        '''\n\n        :param values: shape=batch_size * (max_length-1)\n        :param rewards: shape=batch_size * (max_length-1)\n        :param starts: list of start indices of the answer in each sequence\n        :return:\n        '''\n        # Generalized advantage estimation (https://arxiv.org/abs/1707.06347)\n        # delta_t = r_t + gamma * V(t+1) - V(t); A_t = delta_t + gamma * lam * A_{t+1}\n        logger.debug(f\"[get_advantages_and_returns] values: {values.shape}, rewards: {rewards.shape}, starts: {starts}\")\n        batch_size = rewards.size()[0]\n        length = rewards.size()[-1]\n\n        # lastgaelam = 0\n        # advantages_reversed = []\n        # for t in reversed(range(start, length)):\n        #     nextvalues = values[:, t + 1] if t < length - 1 else 0.0\n        #     delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]\n        #     lastgaelam = delta + self.gamma * self.lam * lastgaelam\n        #     advantages_reversed.append(lastgaelam)\n        # advantages = torch.stack(advantages_reversed[::-1], dim=1)\n        # logger.debug(f\"advantages: {advantages.shape}, values[:, start:]: {values[:, start:].shape}\")\n        # returns = advantages + values[:, start:]\n\n        advantages = []\n        returns = []\n        for i in range(batch_size):\n            lastgaelam = 0\n            advantages_reversed = []\n            for t in reversed(range(starts[i], length)):\n                nextvalues = values[i, t + 1] if t < length - 1 else 0.0\n                delta = rewards[i, t] + self.gamma * nextvalues - values[i, t]\n                lastgaelam = delta + self.gamma * self.lam * lastgaelam\n                advantages_reversed.append(lastgaelam)\n            # set advantage of prompt to 0 (will be ignored when multiplied with action_mask)\n            advantages_reversed.extend([0]*starts[i])\n            advantage = torch.tensor(advantages_reversed[::-1], device=values.device, dtype=values.dtype)\n            advantages.append(advantage)\n            returns.append(advantage + values[i])\n        advantages = torch.stack(advantages)\n        returns = torch.stack(returns)\n\n        return advantages.detach(), returns\n\n    def _validate_training_mode(self):\n        assert self.actor_model.module.training\n        if self.critic_model is not None:\n            assert self.critic_model.module.training\n\n    def _validate_evaluation_mode(self):\n        assert not self.actor_model.module.training\n        assert not self.ref_model.module.training\n        if self.critic_model is not None:\n            assert not self.critic_model.module.training\n        assert not self.reward_model.module.training\n\n    def train(self):\n        self.actor_model.train()\n        if self.critic_model is not None:\n            
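# critic is optional: when it is None, the reward model also supplies the values (see generate_experience)\n            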
self.critic_model.train()\n\n    def eval(self):\n        self.actor_model.eval()\n        self.ref_model.eval()\n        if self.critic_model is not None:\n            self.critic_model.eval()\n        self.reward_model.eval()\n\n    def dump_model_norms(self, tag):\n        actor_model_norm = get_model_norm(self.actor_model)\n        ref_model_norm = get_model_norm(self.ref_model)\n        if self.critic_model is not None:\n            critic_model_norm = get_model_norm(self.critic_model)\n        reward_model_norm = get_model_norm(self.reward_model)\n        if self.args.global_rank <= 0:\n            logger.info(f'{tag} global_actor_model_norm: {actor_model_norm}')\n            logger.info(f'{tag} global_ref_model_norm: {ref_model_norm}')\n            if self.critic_model is not None:\n                logger.info(f'{tag} global_critic_model_norm: {critic_model_norm}')\n            logger.info(f'{tag} global_reward_model_norm: {reward_model_norm}')\n\n\nclass DeepSpeedPPOPTXTrainer(DeepSpeedPPOTrainer):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def train_unsupervised(self, inputs, unsup_coef):\n        # Train the unsupervised model here\n        self._validate_training_mode()\n\n        outputs = self.actor_model(**inputs, use_cache=False)\n        loss = outputs.loss\n        self.actor_model.backward(unsup_coef * loss)\n        self.actor_model.step()\n\n        return loss\n\n\nclass DPOTrainer(Trainer):\n    r\"\"\"\n    Initialize DPOTrainer.\n\n    Args:\n        model (`transformers.PreTrainedModel`):\n            The model to train, preferably an `AutoModelForCausalLM`.\n        ref_model (`PreTrainedModelWrapper`):\n            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no\n            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.\n        beta (`float`, defaults to 0.1):\n            The beta factor in DPO loss. Higher beta means less divergence from the initial policy.\n        args (`transformers.TrainingArguments`):\n            The arguments to use for training.\n        data_collator (`transformers.DataCollator`):\n            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used\n            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.\n        label_pad_token_id (`int`, defaults to `-100`):\n            The label pad token id. This argument is required if you want to use the default data collator.\n        padding_value (`int`, defaults to `0`):\n            The padding value. This argument is required if you want to use the default data collator.\n        truncation_mode (`str`, defaults to `keep_end`):\n            The truncation mode to use, either `keep_end` or `keep_start`. 
This argument is required if you want to use the default data collator.\n        train_dataset (`datasets.Dataset`):\n            The dataset to use for training.\n        eval_dataset (`datasets.Dataset`):\n            The dataset to use for evaluation.\n        tokenizer (`transformers.PreTrainedTokenizerBase`):\n            The tokenizer to use for training. This argument is required if you want to use the default data collator.\n        model_init (`Callable[[], transformers.PreTrainedModel]`):\n            The model initializer to use for training. If None is specified, the default model initializer will be used.\n        callbacks (`List[transformers.TrainerCallback]`):\n            The callbacks to use for training.\n        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):\n            The optimizer and scheduler to use for training.\n        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):\n            The function to use to preprocess the logits before computing the metrics.\n        max_length (`int`, defaults to `None`):\n            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.\n        max_prompt_length (`int`, defaults to `None`):\n            The maximum length of the prompt. This argument is required if you want to use the default data collator.\n        peft_config (`Dict`, defaults to `None`):\n            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.\n        disable_dropout (`bool`, defaults to `True`):\n            Whether or not to disable dropouts in `model` and `ref_model`.\n    \"\"\"\n\n    def __init__(\n            self,\n            model: Union[PreTrainedModel, nn.Module] = None,\n            ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,\n            logps: Optional[Dict[str, Dict[int, Dict[str, torch.FloatTensor]]]] = None,\n            beta: float = 0.1,\n            args: TrainingArguments = None,\n            data_collator: Optional[DataCollator] = None,\n            label_pad_token_id: int = -100,\n            padding_value: int = 0,\n            # truncation_mode: str = \"keep_end\",\n            train_dataset: Optional[Dataset] = None,\n            eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,\n            tokenizer: Optional[PreTrainedTokenizerBase] = None,\n            model_init: Optional[Callable[[], PreTrainedModel]] = None,\n            callbacks: Optional[List[TrainerCallback]] = None,\n            optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (\n                    None,\n                    None,\n            ),\n            preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n            # max_length: Optional[int] = None,\n            # max_prompt_length: Optional[int] = None,\n            # peft_config: Optional[Dict] = None,\n            disable_dropout: bool = True,\n    ):\n        self.is_peft_model = getattr(model, \"is_peft_model\", False)\n\n        if logps is not None:\n            self.logps = logps\n            self.ref_model = None\n        else:\n            self.logps = None\n            if ref_model:\n                self.ref_model = ref_model\n            elif self.is_peft_model:\n                # The `model` with adapters turned off will be used as the reference model\n                
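# (reference logps are then recomputed with the adapters disabled; see get_batch_metrics)\n                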
self.ref_model = None\n            else:\n                self.ref_model = create_reference_model(model)\n\n        # if data_collator is None:\n        #     if tokenizer is None:\n        #         raise ValueError(\n        #             \"max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding\"\n        #         )\n        #     if max_length is None:\n        #         logger.warn(\n        #             \"When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init\"\n        #             \" it will be set to `512` by default, but you should do it yourself in the future.\",\n        #             UserWarning,\n        #         )\n        #         max_length = 512\n        #     if max_prompt_length is None:\n        #         logger.warn(\n        #             \"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init\"\n        #             \" it will be set to `128` by default, but you should do it yourself in the future.\",\n        #             UserWarning,\n        #         )\n        #         max_prompt_length = 128\n        #\n        #     data_collator = DPODataCollatorWithPadding(\n        #         tokenizer,\n        #         max_length=max_length,\n        #         max_prompt_length=max_prompt_length,\n        #         label_pad_token_id=label_pad_token_id,\n        #         padding_value=padding_value,\n        #         truncation_mode=truncation_mode,\n        #     )\n        #\n        #     if args.remove_unused_columns:\n        #         args.remove_unused_columns = False\n        #         # warn users\n        #         warnings.warn(\n        #             \"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments\"\n        #             \" we have set it for you, but you should do it yourself in the future.\",\n        #             UserWarning,\n        #         )\n        #\n        #     self.use_dpo_data_collator = True\n        # else:\n        #     self.use_dpo_data_collator = False\n\n        if disable_dropout:\n            disable_dropout_in_model(model)\n            if self.ref_model is not None:\n                disable_dropout_in_model(self.ref_model)\n\n        self.label_pad_token_id = label_pad_token_id\n        self.padding_value = padding_value\n\n        self.beta = beta\n\n        self._stored_metrics = defaultdict(lambda: defaultdict(list))\n\n        super().__init__(\n            model,\n            args,\n            data_collator,\n            train_dataset,\n            eval_dataset,\n            tokenizer,\n            model_init,\n            None,\n            callbacks,\n            optimizers,\n            preprocess_logits_for_metrics,\n        )\n\n        if not hasattr(self, \"accelerator\"):\n            raise AttributeError(\n                \"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.\"\n            )\n\n        if self.ref_model is None:\n            if self.is_peft_model and not hasattr(\n                    self.accelerator.unwrap_model(self.model).pretrained_model,\n                    \"disable_adapter\",\n            ):\n                raise ValueError(\n                    \"You are using a `peft` version that does not support `disable_adapter`. 
Please update your `peft` version to the latest version.\"\n                )\n        else:\n            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)\n\n    def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:\n        \"\"\"Concatenate the chosen and rejected inputs into a single tensor.\n\n        Args:\n            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).\n\n        Returns:\n            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.\n        \"\"\"\n        max_length = max(batch[\"chosen_input_ids\"].shape[1], batch[\"rejected_input_ids\"].shape[1])\n        concatenated_batch = {}\n        for k in batch:\n            if k.startswith(\"chosen\") and isinstance(batch[k], torch.Tensor):\n                pad_value = self.label_pad_token_id if \"labels\" in k else self.padding_value\n                concatenated_key = k.replace(\"chosen\", \"concatenated\")\n                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)\n        for k in batch:\n            if k.startswith(\"rejected\") and isinstance(batch[k], torch.Tensor):\n                pad_value = self.label_pad_token_id if \"labels\" in k else self.padding_value\n                concatenated_key = k.replace(\"rejected\", \"concatenated\")\n                concatenated_batch[concatenated_key] = torch.cat(\n                    (\n                        concatenated_batch[concatenated_key],\n                        pad_to_length(batch[k], max_length, pad_value=pad_value),\n                    ),\n                    dim=0,\n                ).to(self.accelerator.device)\n        return concatenated_batch\n\n    def dpo_loss(\n            self,\n            policy_chosen_logps: torch.FloatTensor,\n            policy_rejected_logps: torch.FloatTensor,\n            reference_chosen_logps: torch.FloatTensor,\n            reference_rejected_logps: torch.FloatTensor,\n            reference_free: bool = False,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:\n        \"\"\"Compute the DPO loss for a batch of policy and reference model log probabilities.\n\n        Args:\n            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)\n            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)\n            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)\n            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)\n            beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. 
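Higher beta penalizes divergence from the reference policy more strongly. 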
We ignore the reference model as beta -> 0.\n            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.\n\n        Returns:\n            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).\n            The losses tensor contains the DPO loss for each example in the batch.\n            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.\n        \"\"\"\n        pi_logratios = policy_chosen_logps - policy_rejected_logps\n        ref_logratios = reference_chosen_logps - reference_rejected_logps\n\n        if reference_free:\n            ref_logratios = 0\n\n        logits = pi_logratios - ref_logratios\n\n        losses = -F.logsigmoid(self.beta * logits)\n        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()\n        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()\n\n        return losses, chosen_rewards, rejected_rewards\n\n    def _get_batch_logps(\n            self,\n            logits: torch.FloatTensor,\n            labels: torch.LongTensor,\n            average_log_prob: bool = False,\n    ) -> torch.FloatTensor:\n        \"\"\"Compute the log probabilities of the given labels under the given logits.\n\n        Args:\n            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)\n            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)\n            average_log_prob: If True, return the average log probability per (non-masked) token. 
Otherwise, return the sum of the log probabilities of the (non-masked) tokens.\n\n        Returns:\n            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.\n        \"\"\"\n        if logits.shape[:-1] != labels.shape:\n            raise ValueError(\"Logits (batch and sequence length dim) and labels must have the same shape.\")\n\n        labels = labels[:, 1:].clone()\n        logits = logits[:, :-1, :]\n        loss_mask = labels != self.label_pad_token_id\n\n        # dummy token; we'll ignore the losses on these tokens later\n        labels[labels == self.label_pad_token_id] = 0\n\n        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)\n\n        if average_log_prob:\n            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)\n        else:\n            return (per_token_logps * loss_mask).sum(-1)\n\n    def concatenated_forward(\n            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:\n        \"\"\"Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.\n\n        We do this to avoid doing two forward passes, because it's faster for FSDP.\n        \"\"\"\n        concatenated_batch = self.concatenated_inputs(batch)\n        all_logits = model(\n            concatenated_batch[\"concatenated_input_ids\"],\n            attention_mask=concatenated_batch[\"concatenated_attention_mask\"],\n        ).logits.to(torch.float32)\n        all_logps = self._get_batch_logps(\n            all_logits,\n            concatenated_batch[\"concatenated_labels\"],\n            average_log_prob=False,\n        )\n        chosen_logps = all_logps[: batch[\"chosen_input_ids\"].shape[0]]\n        rejected_logps = all_logps[batch[\"chosen_input_ids\"].shape[0] :]\n\n        chosen_logits = all_logits[: batch[\"chosen_input_ids\"].shape[0]]\n        rejected_logits = all_logits[batch[\"chosen_input_ids\"].shape[0] :]\n        return chosen_logps, rejected_logps, chosen_logits, rejected_logits\n\n    def separate_forward(\n            self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:\n        \"\"\"Run the given model on the given batch of inputs, separately process chosen and rejected inputs.\n\n        \"\"\"\n        chosen_logits = model(\n            batch[\"chosen_input_ids\"],\n            attention_mask=batch.get(\"chosen_attention_mask\", None),\n        ).logits.to(torch.float32)\n        chosen_logps = self._get_batch_logps(\n            chosen_logits,\n            batch[\"chosen_labels\"],\n            average_log_prob=False,\n        )\n        rejected_logits = model(\n            batch[\"rejected_input_ids\"],\n            attention_mask=batch.get(\"rejected_attention_mask\", None),\n        ).logits.to(torch.float32)\n        rejected_logps = self._get_batch_logps(\n            rejected_logits,\n            batch[\"rejected_labels\"],\n            average_log_prob=False,\n        )\n        return chosen_logps, rejected_logps, chosen_logits, rejected_logits\n\n    def get_batch_metrics(\n            self,\n            model,\n            batch: Dict[str, Union[List, torch.LongTensor]],\n            train_eval: Literal[\"train\", \"eval\"] = \"train\",\n    ):\n        
\"\"\"Compute the DPO loss and other metrics for the given batch of inputs for train or test.\"\"\"\n        metrics = {}\n\n        (\n            policy_chosen_logps,\n            policy_rejected_logps,\n            policy_chosen_logits,\n            policy_rejected_logits,\n        ) = self.separate_forward(model, batch)\n        # ) = self.concatenated_forward(model, batch)\n        if self.logps is not None:\n            reference_chosen_logps = []\n            reference_rejected_logps = []\n            for idx in batch['index'].detach().cpu().tolist():\n                reference_chosen_logps.append(self.logps[train_eval][idx]['chosen_logp'])\n                reference_rejected_logps.append(self.logps[train_eval][idx]['rejected_logp'])\n            reference_chosen_logps = torch.stack(reference_chosen_logps).to(self.accelerator.device)\n            reference_rejected_logps = torch.stack(reference_rejected_logps).to(self.accelerator.device)\n        else:\n            with torch.no_grad():\n                if self.ref_model is None:\n                    assert self.is_peft_model, \"When ref_model is None, the policy must be a peft model\"\n                    with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():\n                        (\n                            reference_chosen_logps,\n                            reference_rejected_logps,\n                            _,\n                            _,\n                        ) = self.separate_forward(self.model, batch)\n                        # ) = self.concatenated_forward(self.model, batch)\n                else:\n                    (\n                        reference_chosen_logps,\n                        reference_rejected_logps,\n                        _,\n                        _,\n                    ) = self.separate_forward(self.ref_model, batch)\n                    # ) = self.concatenated_forward(self.ref_model, batch)\n\n        losses, chosen_rewards, rejected_rewards = self.dpo_loss(\n            policy_chosen_logps,\n            policy_rejected_logps,\n            reference_chosen_logps,\n            reference_rejected_logps,\n        )\n        reward_accuracies = (chosen_rewards > rejected_rewards).float()\n\n        prefix = \"eval_\" if train_eval == \"eval\" else \"\"\n        metrics[f\"{prefix}rewards/chosen\"] = chosen_rewards.cpu().numpy().mean()\n        metrics[f\"{prefix}rewards/rejected\"] = rejected_rewards.cpu().numpy().mean()\n        metrics[f\"{prefix}rewards/accuracies\"] = reward_accuracies.cpu().numpy().mean()\n        metrics[f\"{prefix}rewards/margins\"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()\n        metrics[f\"{prefix}logps/rejected\"] = policy_rejected_logps.detach().cpu().numpy().mean()\n        metrics[f\"{prefix}logps/chosen\"] = policy_chosen_logps.detach().cpu().numpy().mean()\n        metrics[f\"{prefix}logits/rejected\"] = policy_rejected_logits.detach().cpu().numpy().mean()\n        metrics[f\"{prefix}logits/chosen\"] = policy_chosen_logits.detach().cpu().numpy().mean()\n\n        return losses.mean(), metrics\n\n    def compute_loss(\n            self,\n            model: Union[PreTrainedModel, nn.Module],\n            inputs: Dict[str, Union[torch.Tensor, Any]],\n            return_outputs=False,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:\n        loss, metrics = self.get_batch_metrics(model, inputs, train_eval=\"train\")\n\n        # force log the metrics\n        if 
self.accelerator.is_main_process:\n            self.store_metrics(metrics, train_eval=\"train\")\n\n        if return_outputs:\n            return (loss, metrics)\n        return loss\n\n    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:\n        \"\"\"Generate samples from the model and reference model for the given batch of inputs.\"\"\"\n\n        policy_output = model.generate(\n            batch[\"prompt_input_ids\"],\n            attention_mask=batch[\"prompt_attention_mask\"],\n            max_length=self.config.max_length,\n            do_sample=True,\n            pad_token_id=self.tokenizer.pad_token_id,\n        )\n\n        if self.ref_model is None:\n            with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():\n                reference_output = self.model.generate(\n                    batch[\"prompt_input_ids\"],\n                    attention_mask=batch[\"prompt_attention_mask\"],\n                    max_length=self.config.max_length,\n                    do_sample=True,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                )\n        else:\n            reference_output = self.ref_model.generate(\n                batch[\"prompt_input_ids\"],\n                attention_mask=batch[\"prompt_attention_mask\"],\n                max_length=self.config.max_length,\n                do_sample=True,\n                pad_token_id=self.tokenizer.pad_token_id,\n            )\n\n        policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)\n        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)\n\n        reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id)\n        reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)\n\n        return policy_output_decoded, reference_output_decoded\n\n    def prediction_step(\n            self,\n            model: Union[PreTrainedModel, nn.Module],\n            inputs: Dict[str, Union[torch.Tensor, Any]],\n            prediction_loss_only: bool,\n            ignore_keys: Optional[List[str]] = None,\n    ):\n        if ignore_keys is None:\n            if hasattr(model, \"config\"):\n                ignore_keys = getattr(model.config, \"keys_to_ignore_at_inference\", [])\n            else:\n                ignore_keys = []\n\n        with torch.no_grad():\n            loss, metrics = self.get_batch_metrics(model, inputs, train_eval=\"eval\")\n\n        # force log the metrics\n        if self.accelerator.is_main_process:\n            self.store_metrics(metrics, train_eval=\"eval\")\n\n        if prediction_loss_only:\n            return loss.detach(), None, None\n\n        # logits for the chosen and rejected samples from model\n        logits_dict = {\n            \"logits_test/chosen\": metrics[\"eval_logits/chosen\"],\n            \"logits_test/rejected\": metrics[\"eval_logits/rejected\"],\n        }\n        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)\n        logits = torch.stack(logits).mean(axis=1)\n        labels = torch.zeros(logits.shape[0])\n\n        return loss.detach(), logits, labels\n\n    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal[\"train\", \"eval\"] = \"train\") -> None:\n        for key, value in metrics.items():\n            self._stored_metrics[train_eval][key].append(value)\n\n    def 
log(self, logs: Dict[str, float]) -> None:\n        \"\"\"\n        Log `logs` on the various objects watching training, including stored metrics.\n\n        Args:\n            logs (`Dict[str, float]`):\n                The values to log.\n        \"\"\"\n        # logs either has 'loss' or 'eval_loss'\n        train_eval = \"train\" if \"loss\" in logs else \"eval\"\n        # Add averaged stored metrics to logs\n        for key, metrics in self._stored_metrics[train_eval].items():\n            logs[key] = torch.tensor(metrics).mean().item()\n        del self._stored_metrics[train_eval]\n        return super().log(logs)\n"
  },
  {
    "path": "src/pretrain.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport torch\nimport evaluate\nimport json\n\nfrom tqdm import tqdm\nfrom transformers import (\n    Trainer,\n    TrainingArguments,\n    default_data_collator,\n    GenerationConfig,\n)\n\nfrom src.utils import RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.data.data import PretrainDataset, chatglm2_encode, chatglm3_encode\nfrom src.utils.file_utils import set_seed, print_rank_0\n# from src.models.llama import LlamaForCausalLM\n\n\n# Create a preprocessing function to extract out the proper logits from the model output\ndef preprocess_logits_for_metrics(logits, labels):\n    if isinstance(logits, tuple):\n        logits = logits[0]\n\n    return logits.argmax(dim=-1)\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--bits\", type=int, default=32)\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--concat_samples\", action=\"store_true\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    
parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=16)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--max_length_generation\", type=int, default=None)\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n    parser.add_argument(\"--data_types\", type=str, default=None)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=10)\n    parser.add_argument(\"--top_p\", type=float, default=0.9)\n    parser.add_argument(\"--temperature\", type=float, default=0.8)\n    parser.add_argument(\"--repetition_penalty\", type=float, default=1.0)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load tokenizer and model\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args)\n    if \"baichuan2\" in args.model_name_or_path.lower() or \"qwen\" in args.model_name_or_path.lower():\n        generation_config = GenerationConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)\n        generation_config.do_sample = args.do_sample\n        generation_config.max_new_tokens = args.max_length_generation\n        model.generation_config = generation_config\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, model)\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = PretrainDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                        tokenizer, concat_samples=args.concat_samples)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        dev_dataset = PretrainDataset(args, os.path.join(args.data_dir, args.eval_filename),\n                                      tokenizer, concat_samples=False)\n    else:\n        
dev_dataset = None\n    # if args.do_pred:\n    #     test_dataset = PretrainDataset(args, os.path.join(args.data_dir, args.test_filename),\n    #                                    tokenizer)\n    # else:\n    #     test_dataset = None\n\n    if args.do_train:\n        if torch.cuda.is_available():\n            bf16 = torch.cuda.get_device_capability()[0] >= 8\n            fp16 = not bf16\n        else:\n            fp16 = False\n            bf16 = False\n        # training arguments\n        deepspeed_config = os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config) if args.deepspeed_config is not None else None\n        training_args = TrainingArguments(\n            output_dir=args.output_dir,\n            no_cuda=not torch.cuda.is_available(),\n            seed=args.seed,\n            data_seed=args.seed,\n            local_rank=args.local_rank,\n            do_train=args.do_train,\n            num_train_epochs=args.num_epochs,\n            learning_rate=args.learning_rate,\n            lr_scheduler_type=args.lr_scheduler_type,\n            per_device_train_batch_size=args.train_batch_size,\n            gradient_accumulation_steps=args.gradient_accumulation_steps,\n            warmup_ratio=args.warmup_ratio,\n            weight_decay=args.weight_decay,\n            half_precision_backend=\"auto\",\n            fp16=fp16,\n            bf16=bf16,\n            adam_beta1=0.9,\n            adam_beta2=0.95,\n            save_strategy=args.save_strategy,\n            save_steps=args.save_steps,\n            save_total_limit=args.save_total_limit,\n            metric_for_best_model=args.metric_for_best_model,\n            greater_is_better=True,\n            logging_steps=args.logging_steps,\n            report_to=[\"tensorboard\"],\n            deepspeed=deepspeed_config,\n            gradient_checkpointing=args.gradient_checkpointing,\n            do_eval=args.do_eval,\n            evaluation_strategy=args.evaluation_strategy,\n            eval_steps=args.eval_steps,\n            eval_accumulation_steps=args.eval_accumulation_steps,\n            per_device_eval_batch_size=args.eval_batch_size,\n            # do_predict=args.do_pred,\n            # use_legacy_prediction_loop=args.do_pred,\n        )\n        print_rank_0(f\"Training Arguments: {training_args}\")\n\n        # Set up the metric\n        rouge = evaluate.load(\"rouge\")\n\n        def compute_metrics(eval_preds):\n            labels_ids = eval_preds.label_ids\n            pred_ids = eval_preds.predictions\n            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n            label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n            result = rouge.compute(predictions=pred_str, references=label_str)\n\n            return result\n\n        # Prepare the trainer and start training\n        trainer = Trainer(\n            model=model,\n            args=training_args,\n            train_dataset=train_dataset,\n            eval_dataset=dev_dataset,\n            compute_metrics=compute_metrics,\n            data_collator=default_data_collator,\n            preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n        )\n        # model.config.use_cache = False  # silence the warnings. 
Please re-enable for inference!\n\n        trainer.train()\n        trainer.save_model(args.output_dir)\n\n    elif args.do_eval:\n        # res = trainer.evaluate(eval_dataset=dev_dataset)\n        # logger.info(res)\n        pass\n\n    if args.do_pred:\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() and args.device_map is not None else \"cpu\"\n        model.eval()\n        # tokenizer.padding_side = \"left\"\n        if args.test_filename is not None:\n            test_file = args.test_filename\n        else:\n            test_file = os.path.join(RESOURCE_PATH, \"test_prompts.jsonl\")\n        data_types = args.data_types.split(\",\") if args.data_types is not None else None\n        with torch.no_grad():\n            if args.output_filename is not None:\n                w = open(os.path.join(args.output_dir, args.output_filename), \"w\", encoding=\"utf-8\")\n            else:\n                w = None\n            for line in tqdm(open(test_file, \"r\", encoding=\"utf-8\"), desc=\"Prediction\"):\n            # for line in open(test_file, \"r\", encoding=\"utf-8\"):\n                test_data = json.loads(line.strip(\"\\n\"))\n                data_type = test_data.get('data_type', None)\n                if data_types is not None and data_type not in data_types:\n                    continue\n                prompt = test_data['prompt']\n                prefix = test_data.get('prefix', None)\n                system = test_data.get('system', \"\")\n                label = test_data.get('label', None)\n                # encoded_prompt = tokenizer(prompt)\n                if \"chatglm3\" in args.model_name_or_path.lower():\n                    _, _, prompt_ids = chatglm3_encode(tokenizer, query=prompt, label=None,\n                                                       system=system, max_length=args.max_length, is_prefix=True)\n                    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(prompt_ids)\n                    results = tokenizer.batch_decode([output[prompt_length:] for output in outputs], skip_special_tokens=True)\n                elif \"chatglm2\" in args.model_name_or_path.lower():\n                    # results, history = model.chat(tokenizer, prompt, history=None, do_sample=False,\n                    #                               max_new_tokens=args.max_length_generation)\n                    # results = [results]\n                    _, _, prompt_ids = chatglm2_encode(tokenizer, query=prompt, label=None,\n                                                       system=system, max_length=args.max_length, is_prefix=True)\n                    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n                    outputs = model.generate(input_ids=input_ids,\n                               
              max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(prompt_ids)\n                    results = tokenizer.batch_decode([output[prompt_length:] for output in outputs], skip_special_tokens=True)\n                elif \"chatglm\" in args.model_name_or_path.lower():\n                    inputs = tokenizer(prompt, max_length=args.max_length-args.max_length_generation,\n                                       truncation=\"only_first\",\n                                       return_tensors=\"pt\")\n                    inputs = inputs.to(device)\n                    outputs = model.generate(inputs=inputs['input_ids'],\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty)\n                    prompt_length = len(inputs['input_ids'][0])\n                    results = tokenizer.batch_decode([output[prompt_length:] for output in outputs], skip_special_tokens=True)\n                # elif \"glm\" in args.model_name_or_path.lower():\n                #     encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n                #     prompt_length = len(encoded_prompt['input_ids'])\n                #     encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n                #                              max_length=min(prompt_length, args.max_length),\n                #                              truncation=\"only_first\",\n                #                              return_tensors=\"pt\",\n                #                              return_token_type_ids=False)\n                #     max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n                #     inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n                #                                                    max_gen_length=max_gen_length, padding=True)\n                #     inputs = inputs.to(device)\n                #     outputs = model.generate(**inputs,\n                #                              max_new_tokens=min(args.max_length_generation, max_gen_length),\n                #                              eos_token_id=tokenizer.eop_token_id,\n                #                              pad_token_id=tokenizer.pad_token_id,\n                #                              do_sample=args.do_sample,\n                #                              
num_return_sequences=args.num_return_sequences,\n                #                              top_k=args.top_k,\n                #                              top_p=args.top_p,\n                #                              temperature=args.temperature)\n                elif \"baichuan2\" in args.model_name_or_path.lower():\n                    messages = [{\"role\": \"user\", \"content\": prompt}]\n                    results = model.chat(tokenizer, messages)\n                    results = [results]\n                elif \"qwen\" in args.model_name_or_path.lower():\n                    results, history = model.chat(tokenizer, prompt, history=None)\n                    results = [results]\n                else:\n                    if prefix is not None and len(prefix) > 0:\n                        prompt += prefix\n                    inputs = tokenizer(prompt, max_length=args.max_length-args.max_length_generation,\n                                       truncation=\"only_first\",\n                                       return_tensors=\"pt\")\n                    inputs = inputs.to(device)\n                    outputs = model.generate(inputs=inputs['input_ids'],\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature,\n                                             repetition_penalty=args.repetition_penalty)\n                    prompt_length = len(inputs['input_ids'][0])\n                    results = tokenizer.batch_decode([output[prompt_length:] for output in outputs], skip_special_tokens=True)\n                # p = tokenizer.decode(encoded_prompt['input_ids'], skip_special_tokens=True)\n                answers = []\n                for r in results:\n                    # answer = r.replace(p, \"\").strip()\n                    answer = r\n                    # print_rank_0(f\"\\nprompt: {prompt}\\nanswer: {answer}\")\n                    answers.append({\"answer\": answer, \"score\": None})\n                if w is not None:\n                    w.write(json.dumps({\"prompt\": prompt, \"prefix\": prefix, \"system\": system, \"answers\": answers,\n                                        \"label\": label}, ensure_ascii=False)+\"\\n\")\n\n            if w is not None:\n                w.close()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/pretrain_wo_trainer.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport torch\nimport evaluate\nimport json\nimport numpy as np\nimport deepspeed\n\nfrom datetime import datetime\nfrom torch.utils.data import DistributedSampler, DataLoader\nfrom transformers.deepspeed import HfDeepSpeedConfig\n# from deepspeed.ops.adam import FusedAdam\n# from deepspeed.ops.adam import DeepSpeedCPUAdam\n\nfrom src.utils import logger, RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.data.data import PretrainDataset\nfrom src.utils.file_utils import set_seed, print_gpu_utilization, print_rank_0\nfrom src.utils.modeling_utils import rotate_checkpoints, save_zero_three_model\n# from src.models import convert_to_lora_recursively\n# from src.models.llama import LlamaForCausalLM\n\n\n# Create a preprocessing function to extract out the proper logits from the model output\ndef preprocess_logits_for_metrics(logits, labels):\n    if isinstance(logits, tuple):\n        logits = logits[0]\n\n    return logits.argmax(dim=-1)\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    parser.add_argument(\"--max_length_generation\", type=int, default=None)\n    parser.add_argument(\"--bits\", type=int, default=32)\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--concat_samples\", action=\"store_true\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"OneCycle\",\n                        help=\"deepspeed scheduler types, including:\"\n                             \"LRRangeTest, OneCycle, WarmupLR, WarmupDecayLR\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save 
is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=16)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=10)\n    parser.add_argument(\"--top_p\", type=float, default=0.9)\n    parser.add_argument(\"--temperature\", type=float, default=0.8)\n    parser.add_argument(\"--repetition_penalty\", type=float, default=1.0)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef pred_single_sample(prompt, prefix, model, tokenizer, args, device, eos_token_id):\n    max_prompt_length = args.max_length - args.max_length_generation\n    if \"chatglm\" in args.model_name_or_path.lower():\n        encoded_prompt = tokenizer(prompt)\n        prompt_length = len(encoded_prompt['input_ids'])\n        inputs = tokenizer(prompt,\n                           max_length=min(prompt_length, args.max_length),\n                           truncation=\"only_first\",\n                           return_tensors=\"pt\")\n        # max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n        # inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n        #                                                max_gen_length=max_gen_length, padding=True)\n        input_ids = inputs['input_ids']\n        inputs = inputs.to(device)\n        outputs = model.generate(inputs=inputs['input_ids'],\n                                 max_new_tokens=args.max_length_generation,\n                                 
eos_token_id=eos_token_id,\n                                 pad_token_id=tokenizer.pad_token_id,\n                                 do_sample=args.do_sample,\n                                 num_return_sequences=args.num_return_sequences,\n                                 top_k=args.top_k,\n                                 top_p=args.top_p,\n                                 temperature=args.temperature)\n    # elif \"glm\" in args.model_name_or_path.lower():\n    #     encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n    #     prompt_length = len(encoded_prompt['input_ids'])\n    #     encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n    #                              max_length=min(prompt_length, args.max_length),\n    #                              truncation=\"only_first\",\n    #                              return_tensors=\"pt\",\n    #                              return_token_type_ids=False)\n    #     input_ids = encoded_dict['input_ids']\n    #     max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n    #     inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n    #                                                    max_gen_length=max_gen_length, padding=True)\n    #     inputs = inputs.to(device)\n    #     outputs = model.generate(inputs=inputs['input_ids'],\n    #                              max_new_tokens=min(args.max_length_generation, max_gen_length),\n    #                              eos_token_id=tokenizer.eop_token_id,\n    #                              pad_token_id=tokenizer.pad_token_id,\n    #                              do_sample=args.do_sample,\n    #                              num_return_sequences=args.num_return_sequences,\n    #                              top_k=args.top_k,\n    #                              top_p=args.top_p,\n    #                              temperature=args.temperature)\n    else:\n        inputs = tokenizer(prompt, max_length=max_prompt_length, truncation=\"longest_first\", return_tensors=\"pt\")\n        input_ids = inputs['input_ids']\n        inputs = inputs.to(device)\n        outputs = model.generate(inputs=inputs['input_ids'],\n                                 max_new_tokens=args.max_length_generation,\n                                 do_sample=args.do_sample,\n                                 num_return_sequences=args.num_return_sequences,\n                                 top_k=args.top_k,\n                                 top_p=args.top_p,\n                                 temperature=args.temperature,\n                                 repetition_penalty=args.repetition_penalty)\n\n    results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n    # input_ids has shape [1, seq_len], so decode its first row to recover the prompt text\n    p = tokenizer.decode(input_ids[0], skip_special_tokens=True)\n    results = [result.replace(p, \"\").strip() for result in results]\n    answers = []\n    for r in results:\n        print_rank_0(f\"\\nprompt: {p}\\nanswer: {r}\")\n        answers.append({\"answer\": r, \"score\": None})\n    d = {\"prompt\": prompt, \"prefix\": prefix, \"answers\": answers}\n\n    return d\n\n\ndef pred(args, model, tokenizer, device, eos_token_id, step=-1):\n    print_rank_0(f\"Prediction Result@{step}\")\n    with torch.no_grad():\n        with open(os.path.join(args.output_dir, args.output_filename.format(step=step)), \"w\", encoding=\"utf-8\") as w:\n            with open(os.path.join(args.data_dir, args.test_filename), \"r\", encoding=\"utf-8\") as r:\n                while True:\n                    line = r.readline()\n   
                 if not line:\n                        break\n                    item = json.loads(line.strip(\"\\n\"))\n                    prompt = item['context']\n                    result = pred_single_sample(prompt, \"\", model, tokenizer, args, device, eos_token_id)\n                    if args.local_rank <= 0:\n                        w.write(json.dumps(result, ensure_ascii=False)+\"\\n\")\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    torch.cuda.set_device(args.local_rank)\n    device = torch.device(\"cuda\", args.local_rank)\n\n    # load quantization config\n    if torch.cuda.is_available():\n        bf16 = torch.cuda.get_device_capability()[0] >= 8\n        fp16 = not bf16\n    else:\n        fp16 = False\n        bf16 = False\n\n    # create HfDeepSpeedConfig [must be called before instantiating model]\n    if args.deepspeed_config is not None:\n        ds_config_filename = os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config)\n        ds_config = json.load(open(ds_config_filename, \"r\", encoding=\"utf-8\"))\n        # ds_config[\"steps_per_print\"] = args.logging_steps\n        ds_config[\"train_micro_batch_size_per_gpu\"] = args.train_batch_size\n        ds_config[\"gradient_accumulation_steps\"] = args.gradient_accumulation_steps\n        ds_config[\"gradient_clipping\"] = args.max_grad_norm\n        # TODO: before calling dist init, world size is always 1, therefore ds_config['train_batch_size'] cannot multiply world size\n        ds_config['train_batch_size'] = args.train_batch_size * args.gradient_accumulation_steps #* torch.cuda.device_count()\n        # TODO: assuming hidden_size=4096\n        ds_config[\"zero_optimization\"][\"reduce_bucket_size\"] = 4096 * 4096\n        ds_config[\"zero_optimization\"][\"stage3_prefetch_bucket_size\"] = 0.9 * 4096 * 4096\n        ds_config[\"zero_optimization\"][\"stage3_param_persistence_threshold\"] = 10 * 4096\n        ds_config[\"fp16\"][\"enabled\"] = fp16\n        ds_config[\"bf16\"][\"enabled\"] = bf16\n        ds_config[\"optimizer\"][\"params\"] = {\n                \"lr\": args.learning_rate,\n                \"betas\": [0.9, 0.95],\n                \"eps\": 1e-8,\n                \"weight_decay\": args.weight_decay\n            }\n        assert ds_config[\"scheduler\"]['type'] == args.lr_scheduler_type\n        ds_config[\"scheduler\"][\"params\"] = {\n                    \"cycle_min_lr\": 0,\n                    \"cycle_max_lr\": args.learning_rate,\n                    \"cycle_first_step_size\": args.warmup_steps\n                }\n        current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n        ds_config['tensorboard']['job_name'] = f\"deepspeed-{current_time}\"\n        dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive\n\n    # load tokenizer and model\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args, with_trainer=False)\n    print_gpu_utilization(\"after from_pretrained()\", args.local_rank)\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, model)\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = PretrainDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                        tokenizer, concat_samples=args.concat_samples)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        eval_dataset = PretrainDataset(args, 
os.path.join(args.data_dir, args.eval_filename),\n                                      tokenizer, concat_samples=False)\n        # Set up the metric\n        rouge = evaluate.load(\"rouge\")\n\n        def compute_metrics(pred_ids, label_ids):\n            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n            label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n            result = rouge.compute(predictions=pred_str, references=label_str)\n\n            return result\n    else:\n        eval_dataset = None\n\n    if args.do_train:\n        # # Optimizer\n        # AdamOptimizer = DeepSpeedCPUAdam if \"3\" in args.deepspeed_config else FusedAdam\n        # optim_params = get_optimizer_grouped_parameters(\n        #     actor_model, self.args.actor_weight_decay)\n        # optim = AdamOptimizer(optim_params,\n        #                       lr=self.args.actor_learning_rate,\n        #                       betas=(0.9, 0.95))\n        #\n        # # LR Scheduler\n        # lr_scheduler = get_scheduler(\n        #     name=self.args.lr_scheduler_type,\n        #     optimizer=optim,\n        #     num_warmup_steps=self.args.warmup_steps,\n        #     num_training_steps=self.num_total_iters,\n        # )\n\n        # deepspeed initialize\n        ds_config['train_batch_size'] = args.train_batch_size * args.gradient_accumulation_steps * torch.cuda.device_count()\n        model_engine, *_ = deepspeed.initialize(model=model,\n                                                # optimizer=optim,\n                                                # lr_scheduler=lr_scheduler,\n                                                config=ds_config)\n        print_gpu_utilization(\"after deepspeed.initialize()\", args.local_rank)\n\n        # create data loader\n        train_sampler = DistributedSampler(train_dataset)\n        train_dataloader = DataLoader(\n            train_dataset,\n            # collate_fn=data_collator,\n            sampler=train_sampler,\n            batch_size=args.train_batch_size)\n\n        if args.do_eval:\n            eval_sampler = DistributedSampler(eval_dataset)\n            eval_dataloader = DataLoader(\n                eval_dataset,\n                # collate_fn=data_collator,\n                sampler=eval_sampler,\n                batch_size=args.eval_batch_size)\n\n            def eval(step):\n                model_engine.eval()\n                eval_results = dict()\n                with torch.no_grad():\n                    for eval_batch in eval_dataloader:\n                        eval_batch = {k: v.to(device) for k, v in eval_batch.items()}\n                        eval_output = model_engine(**eval_batch)\n                        pred_ids = preprocess_logits_for_metrics(eval_output.logits, None)\n                        result_rouge = compute_metrics(pred_ids, eval_batch['labels'])\n                        for k, v in result_rouge.items():\n                            key = f\"eval_{k}\"\n                            if key not in eval_results:\n                                eval_results[key] = []\n                            eval_results[key].append(v)\n                        if \"eval_loss\" not in eval_results:\n                            eval_results['eval_loss'] = []\n                        eval_results['eval_loss'].append(eval_output.loss.tolist())\n                if args.do_pred:\n                    pred(args, model_engine, tokenizer, device, eos_token_id, step)\n                model_engine.train()\n        
        for k, v in eval_results.items():\n                    eval_results[k] = np.mean(eval_results[k])\n                return eval_results\n\n        # training\n        model_engine.train()\n        if args.gradient_checkpointing:\n            model_engine.module.gradient_checkpointing_enable()\n        print_gpu_utilization(\"before training begin\", args.local_rank)\n        global_step = 0\n        best_metric = None\n        best_model_checkpoint = None\n        if args.do_eval:\n            assert args.eval_steps <= args.save_steps and args.save_steps % args.eval_steps == 0, \\\n                f\"save_steps should be no less than eval_steps and a multiple of eval_steps\"\n            eval_results = eval(global_step)\n            print_rank_0(f\"Epoch-0, Global step-{global_step}, Evaluation result: {eval_results}\")\n            if args.metric_for_best_model is not None:\n                assert args.metric_for_best_model in eval_results, \\\n                    f\"{args.metric_for_best_model} is not a valid metric, \" \\\n                    f\"please choose from the following metrics: {eval_results.keys()}\"\n        for epoch in range(args.num_epochs):\n            print_rank_0(f\"Beginning of Epoch {epoch+1}/{args.num_epochs}\")\n            for step, batch in enumerate(train_dataloader):\n                batch = {k: v.to(device) for k, v in batch.items()}\n                logger.debug(f\"batch keys: {batch.keys()}\")\n                output = model_engine(**batch)\n                model_engine.backward(output.loss)\n                model_engine.step()\n                global_step += 1\n                if global_step % args.logging_steps == 0:\n                    print_rank_0(f\"Epoch-{epoch+1}, Global step-{global_step}, loss: {output.loss}\")\n                if args.do_eval and global_step % args.eval_steps == 0:\n                    eval_results = eval(global_step)\n                    print_rank_0(f\"Epoch-{epoch+1}, Global step-{global_step}, Evaluation result: {eval_results}\")\n                if global_step % args.save_steps == 0:\n                    output_dir = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    if args.do_eval and args.metric_for_best_model is not None:\n                        if (\n                                best_metric is None or\n                                best_model_checkpoint is None or\n                                eval_results[args.metric_for_best_model] > best_metric\n                        ):\n                            best_metric = eval_results[args.metric_for_best_model]\n                            best_model_checkpoint = output_dir\n                    rotate_checkpoints(args.save_total_limit, use_mtime=True, output_dir=args.output_dir,\n                                       best_model_checkpoint=best_model_checkpoint)\n                    # save_zero_three_model(model_engine, args.local_rank,\n                    #                       save_dir=output_dir,\n                    #                       zero_stage=ds_config['zero_optimization']['stage'])\n                    # model_engine.save_16bit_model(output_dir)\n                    model_engine.save_checkpoint(args.output_dir, f\"checkpoint-{global_step}\")\n                    print_rank_0(f\"Finished saving checkpoint @Step-{global_step}\")\n\n        print_rank_0(f\"Finished training! 
epochs: {epoch+1}, steps: {global_step}\")\n        output_dir = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n        # save_zero_three_model(model_engine, args.local_rank,\n        #                       save_dir=output_dir,\n        #                       zero_stage=ds_config['zero_optimization']['stage'])\n        # model_engine.save_16bit_model(output_dir)\n        model_engine.save_checkpoint(args.output_dir, f\"checkpoint-{global_step}\")\n        print_rank_0(f\"Finished saving checkpoint @Step-{global_step}\")\n\n    elif args.do_eval:\n        pass\n\n    if args.do_pred:\n        model.eval()\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() else \"cpu\"\n        model = model.to(device)\n        # tokenizer.padding_side = \"left\"\n        pred(args, model, tokenizer, device, eos_token_id)\n\n    \nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/resources/config/deepspeed/Experiment.md",
    "content": "\n为验证不同预训练模型使用deepspeed的训练效率是否能达到官方宣称的效果（加速、节省GPU等），进行了benchmarking\n- 实验场景：SFT阶段训练\n- 实验数据：SFT & Reward Data的验证集，共1万条样本\n- 实验参数：```batch_size=1, max_sequence_length=512, gradient_checkpointing=False```\n\n\n<table>\n    <thead>\n        <tr> <td>模型</td>  <td>整体耗时/epoch</td>  <td>单条样本耗时</td>  <td>内存使用量</td>  <td>显存使用量</td>  <td>GPU型号和数量</td> <td>fp16</td> <td>bf16</td> <td>deepspeed stage</td> <td>offload optimizer</td> <td>pin memory</td> <td>offloard param</td> <td>overlap comm</td> <td>allgather bucket size</td> <td>stage3 max live parameters</td> </tr>\n    </thead>\n   <tbody>\n       <tr> <td rowspan=\"15\">Pangu-350M</td>  <td>20min</td>  <td>1.17s/it</td>  <td></td>  <td>1*8750MB</td>  <td>1*V100 16G</td>  <td>false</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>19min</td>  <td>1.03s/it</td>  <td></td>  <td>1*9010MB</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>7.5min</td>  <td>1.10s/it</td>  <td></td>  <td>3*9406MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>0</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>6.5min</td>  <td>1.05s/it</td>  <td></td>  <td>3*5674MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>1</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>14min</td>  <td>2.10s/it</td>  <td></td>  <td>3*6262MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>false</td> <td>-</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td>18min</td>  <td>2.6s/it</td>  <td>18G</td>  <td>3*3668MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> </tr>\n       <tr> <td>18.5min</td>  <td>2.65s/it</td>  <td>18G</td>  <td>3*4240MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td>18.5min</td>  <td>2.65s/it</td>  <td>18G</td>  <td>3*5194MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td>19min</td>  <td>2.80s/it</td>  <td>18G</td>  <td>3*12824MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>5e9</td> <td>-</td> </tr>\n       <tr> <td>47min</td>  <td>6.75s/it</td>  <td>14G</td>  <td>3*4914MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>false</td> <td>-</td> <td>false</td> <td>false</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td>3.3h</td>  <td>29s/it</td>  <td>18G</td>  <td>3*3512MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>false</td> <td>false</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td>4h</td>  <td>34s/it</td>  <td>24G</td>  <td>3*3466MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td>3.8h</td>  <td>33s/it</td>  <td>24G</td>  <td>3*3746MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td>3.8h</td>  <td>33s/it</td>  
<td>24G</td>  <td>3*3594MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td>4h</td>  <td>35s/it</td>  <td>24G</td>  <td>3*3526MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>2e8</td> <td>-</td> </tr>\n       <tr> <td rowspan=\"10\">Pangu-2.6B</td>  <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>false</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>1*V100 16G</td>  <td>true</td>  <td>-</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>0</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>1</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>false</td> <td>-</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td>1.5h</td>  <td>12.3s/it</td>  <td>59G</td>  <td>3*10796MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> </tr>\n       <tr> <td>1.3h</td>  <td>12s/it</td>  <td>59G</td>  <td>3*11368MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td>-</td>  <td>-</td>  <td>-</td>  <td>OOM</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>false</td> <td>-</td> <td>false</td> <td>false</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td>7.5h</td>  <td>64.5s/it</td>  <td>58G</td>  <td>3*13428MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>false</td> <td>false</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td>11.3h</td>  <td>95s/it</td>  <td>109G</td>  <td>3*12170MB</td>  <td>3*V100 16G</td>  <td>true</td>  <td>-</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td rowspan=\"17\">Pangu-2.6B</td>  <td></td>  <td>1.32s/it</td>  <td></td>  <td>1*49347MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>false</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>1.27s/it</td>  <td></td>  <td>1*52783MB</td>  <td>1*A100 80G</td>  <td>true</td>  <td>false</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>1.27s/it</td>  <td></td>  <td>1*52783MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>-</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>-</td>  <td></td>  <td>CUDA Error</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>0</td>  <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>-</td>  <td></td>  <td>CUDA Error</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>1</td>  <td>-</td> 
<td>-</td> <td>-</td> <td>-</td> <td>-</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>-</td>  <td></td>  <td>CUDA Error</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>2</td>  <td>false</td> <td>-</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>9-11s/it</td>  <td></td>  <td>1*12537MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>2e8</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>8-9s/it</td>  <td></td>  <td>1*13539MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>9-11s/it</td>  <td></td>  <td>1*15041MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>false</td> <td>1e9</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>8-9s/it</td>  <td></td>  <td>1*14887MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>2</td>  <td>true</td> <td>true</td> <td>-</td> <td>true</td> <td>5e8</td> <td>-</td> </tr>\n       <tr> <td></td>  <td>-</td>  <td></td>  <td>CUDA Error</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>false</td> <td>-</td> <td>false</td> <td>false</td> <td>5e8</td> <td>1e9</td> </tr>\n       <tr> <td></td>  <td>17-18s/it</td>  <td></td>  <td>1*16935MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>true</td> <td>true</td> <td>false</td> <td>false</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td></td>  <td>20-21s/it</td>  <td></td>  <td>1*12219MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>false</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td></td>  <td>19-20s/it</td>  <td></td>  <td>1*15981MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>-</td> <td>1e9</td> </tr>\n       <tr> <td></td>  <td>22-23s/it</td>  <td></td>  <td>1*12023MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>-</td> <td>5e8</td> </tr>\n       <tr> <td></td>  <td>20-21s/it</td>  <td></td>  <td>1*12023MB</td>  <td>1*A100 80G</td>  <td>false</td>  <td>true</td>  <td>3</td>  <td>true</td> <td>true</td> <td>true</td> <td>true</td> <td>-</td> <td>2e8</td> </tr>\n    </tbody>\n</table>\n"
  },
  {
    "path": "src/resources/config/deepspeed/README.md",
    "content": "# Deepspeed Configuration\n- deepspeed==0.8.2\n- transformers==4.26.1\n\n- using ```transformers.Trainer``` and ```transformers.TrainingArguments```\n\nExample of deepspeed config with key items explained:\n```bash\n{\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"gradient_clipping\": \"auto\",\n    \"steps_per_print\": 100,\n    \"wall_clock_breakdown\": false # \"Enable timing of the latency of forward/backward/update training phases\"\n    \n    \"optimizer\": {\n        \"type\": \"Adam\",\n        \"params\": {\n          \"lr\": \"auto\",\n          \"betas\": \"auto\",\n          \"eps\": \"auto\",\n          \"weight_decay\": \"auto\"\n        }\n    },\n\n    \"scheduler\": {\n        \"type\": \"WarmupLR\",\n        \"params\": {\n            \"warmup_min_lr\": \"auto\",\n            \"warmup_max_lr\": \"auto\",\n            \"warmup_num_steps\": \"auto\"\n        }\n    },\n    \n    \"fp16\": {\n        \"enabled\": \"auto\",\n        \"auto_cast\": false, # automatically casts inputs to fp16\n        \"loss_scale\": 0, # a fp16 parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling\n        \"initial_scale_power\": 16,\n        \"loss_scale_window\": 1000,\n        \"hysteresis\": 2,\n        \"min_loss_scale\": 1\n    },\n\n    # \"BFLOAT16 requires hardware support (e.g., NVIDIA A100). Training with bfloat16 does not require loss scaling\"\n    \"bf16\": {\n        \"enabled\": \"auto\"\n    },\n    \n    \"zero_optimization\": {\n        \"stage\": [0|1|2|3], # \"Stage 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively.\"\n        \"offload_optimizer\": {\n            \"device\": \"[cpu|nvme]\",\n            \"pin_memory\": true, # \"This feature can improve the throughput at the cost of making less memory available to other processes. Pinned memory is set aside to the specific process that requested it and its typically accessed much faster than normal CPU memory\"\n            # all nvme-related params\n            \"nvme_path\": \"/local_nvme\",\n            \"buffer_count\": 4,\n            \"fast_init\": false\n        },\n        \"offload_param\": {\n            \"device\": \"[cpu|nvme]\",\n            \"pin_memory\": true, # \"This feature can improve the throughput at the cost of making less memory available to other processes. Pinned memory is set aside to the specific process that requested it and its typically accessed much faster than normal CPU memory\"\n            # all nvme-related params\n            \"nvme_path\": \"/local_nvme\",\n            \"buffer_count\": 5,\n            \"buffer_size\": 1e8,\n            \"max_in_cpu\": 1e9\n        },\n        \"overlap_comm\": false, # \"if set to true, trades off increased GPU RAM usage to lower all-reduce latency. overlap_comm uses 4.5x the allgather_bucket_size and reduce_bucket_size values. So if they are set to 5e8, this requires a 9GB footprint (5e8 x 2Bytes x 2 x 4.5). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting OOM-errors you will need to reduce those parameters to about 2e8, which would require 3.6GB\"\n        \"reduce_bucket_size\": \"auto\", # \"Number of elements reduced/allreduced at a time. 
Limits the memory required for the allgather for large model sizes. When set to auto, it equals hidden_size*hidden_size\"\n        # only stage-2 params\n        \"allgather_partitions\": true,\n        \"allgather_bucket_size\": 5e8, # \"Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes\"\n        \"reduce_scatter\": true,\n        \"contiguous_gradients\" : true, # \"Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass\"\n        \"round_robin_gradients\": [true|false], # \"Stage 1 and 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism).\"\n        # only stage-3 params\n        \"stage3_max_live_parameters\" : 1e9, # \"The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication. 1e9 would consume ~2GB\"\n        \"stage3_max_reuse_distance\" : 1e9, # \"Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication.\"\n        \"stage3_prefetch_bucket_size\" : \"auto\", # \"The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication. When set to auto, it equals 0.9 * hidden_size * hidden_size\"\n        \"stage3_param_persistence_threshold\" : \"auto\", # \"Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). When set to auto, it equals 10 * hidden_size\"\n        \"sub_group_size\" : 1e12, # controls the granularity in which parameters are updated during optimizer steps. Parameters are grouped into buckets of sub_group_size and each bucket is updated one at a time. When used with NVMe offload in ZeRO-Infinity, sub_group_size therefore controls the granularity in which model states are moved in and out of CPU memory from NVMe during the optimizer step. This prevents running out of CPU memory for extremely large models. \n        \"elastic_checkpoint\" : [true|false],\n        \"stage3_gather_16bit_weights_on_model_save\": true, # Consolidate the weights before saving the model by save_16bit_model(). Since the weights are partitioned across GPUs, they aren’t part of state_dict, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights.\n        \"ignore_unused_parameters\": true # Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected\n    },\n  \n    #  DeepSpeed Autotuner automatically discovers the optimal DeepSpeed configuration that delivers good training speed\n    \"autotuning\": {},\n  \n    # Flops Profiler helps users easily measure both the model training/inference speed (latency, throughput) and efficiency (floating-point operations per second, i.e., FLOPS) of a model and its submodules\n    \"flops_profiler\": {},\n    \n    \"activation_checkpointing\": {},\n  \n    \"sparse_attention\": {},\n  \n    # DeepSpeed Data Efficiency Library includes two techniques: curriculum learning and random layerwise token dropping (random-LTD).\n    \"data_efficiency\": {},\n  \n    # Compression has seven different components, including layer reduction, weight quantization, activation quantization, sparse pruning, row pruning, head pruning, and channel pruning\n    \"compression_training\": {}\n}\n```\n\nNote that the special value ```auto``` in the configuration will be automatically replaced with the correct or most efficient value from ```transformers.TrainingArguments```\n
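\nFor reference, here is a minimal sketch of how such a config file gets passed to ```transformers.Trainer``` (mirroring what ```src/pretrain.py``` does through the ```deepspeed``` training argument; the hyperparameter values below are illustrative only):\n```python\nfrom transformers import Trainer, TrainingArguments\n\n# every \"auto\" entry in the deepspeed config is filled in from these arguments\ntraining_args = TrainingArguments(\n    output_dir=\"output\",\n    per_device_train_batch_size=1,  # -> \"train_micro_batch_size_per_gpu\"\n    gradient_accumulation_steps=8,  # -> \"gradient_accumulation_steps\"\n    learning_rate=1e-6,             # -> optimizer \"lr\"\n    fp16=True,                      # -> fp16.enabled\n    deepspeed=\"ds_config.json\",     # path to a config file like the example above (must exist)\n)\n# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)\n# trainer.train()\n```\n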
## Tricks & Troubleshooting\n- If you are training from scratch, try to have tensors with shapes that are divisible by 16 (e.g. hidden size). For batch size, try divisible by 2 at least. There are also hardware-specific wave and tile quantization divisibility rules if you want to squeeze even higher performance from your GPUs.\n\n- It’s possible to adjust the ZeRO-3 configuration to make it perform closer to ZeRO-2 (see the sketch after this list):\n    - set ```stage3_param_persistence_threshold``` to a very large number - larger than the largest parameter, e.g., ```6 * hidden_size * hidden_size```. This will keep the parameters on the GPUs.\n    - turn off ```offload_param``` since ZeRO-2 doesn’t have that option.\n\n- ```overlap_comm```: if true, trades off increased GPU RAM usage to lower all-reduce latency. ```overlap_comm``` uses 4.5x the ```allgather_bucket_size``` and ```reduce_bucket_size``` values. So if they are set to 5e8, this requires a 9GB footprint (5e8 x 2Bytes x 2 x 4.5). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting OOM-errors you will need to reduce those parameters to about 2e8, which would require 3.6GB. You will want to do the same on larger-capacity GPUs as well if you’re starting to hit OOM.\n\n- ```sub_group_size```: You can leave it at its default value when not using NVMe offload. You may want to change its default value in the following cases:\n  - Running into OOM during optimizer step: Reduce ```sub_group_size``` to reduce memory utilization of temporary buffers;\n  - Optimizer step is taking a long time: Increase ```sub_group_size``` to improve bandwidth utilization as a result of the increased data buffers.\n\n- ```activation_checkpointing```: activation checkpointing and gradient checkpointing refer to the same methodology, but enabling ```activation_checkpointing``` in the DeepSpeed config has no effect on HuggingFace Transformers. If you want to use it with an HF Transformers model, call ```model.gradient_checkpointing_enable()``` or use ```--gradient_checkpointing``` in the HF Trainer, which will automatically enable this for you.\n\n- If you are using fp16 and see Deepspeed report ```OVERFLOW!``` in your log as follows:\n```bash\n[deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 262144, reducing to 262144\n```\nthat means that the Deepspeed loss scaler can’t figure out a scaling coefficient that overcomes loss overflow. In this case you usually need to raise the value of ```initial_scale_power``` to 32, which will typically solve the problem.\n\n
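As a sketch of the ZeRO-3-closer-to-ZeRO-2 adjustment above (```hidden_size = 4096``` is an assumed model dimension, and ```TrainingArguments``` also accepts an in-memory dict in place of a json file path):\n```python\n# Hypothetical ZeRO-3 config tuned to behave closer to ZeRO-2.\nhidden_size = 4096  # assumption: read the real value from your model's config\n\nds_config = {\n    \"train_batch_size\": \"auto\",\n    \"train_micro_batch_size_per_gpu\": \"auto\",\n    \"gradient_accumulation_steps\": \"auto\",\n    \"zero_optimization\": {\n        \"stage\": 3,\n        # larger than the largest parameter: keeps all parameters on the GPUs\n        \"stage3_param_persistence_threshold\": 6 * hidden_size * hidden_size,\n        # note: no \"offload_param\" section, since ZeRO-2 doesn't have that option\n    },\n    \"fp16\": {\"enabled\": \"auto\"},\n}\n# training_args = TrainingArguments(output_dir=\"output\", deepspeed=ds_config)\n```\n\n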
# Main Resources\n\n- [Huggingface Deepspeed Integration](https://huggingface.co/docs/transformers/main_classes/deepspeed#deepspeed-integration)\n- [Deepspeed Docs](https://www.deepspeed.ai/)\n- [Deepspeed Github](https://github.com/microsoft/DeepSpeed)"
  },
  {
    "path": "src/resources/config/deepspeed/stage-1.json",
    "content": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"steps_per_print\": 100,\n  \"wall_clock_breakdown\": false,\n  \"zero_optimization\": {\n    \"stage\": 1\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"bf16\": {\n      \"enabled\": \"auto\"\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": \"auto\",\n      \"betas\": \"auto\",\n      \"eps\": \"auto\",\n      \"weight_decay\": \"auto\"\n    }\n  }\n}\n"
  },
  {
    "path": "src/resources/config/deepspeed/stage-2.json",
    "content": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"steps_per_print\": 100,\n  \"wall_clock_breakdown\": false,\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"overlap_comm\": false,\n    \"allgather_partitions\": true,\n    \"allgather_bucket_size\": 5e8,\n    \"reduce_scatter\": true,\n    \"reduce_bucket_size\": 5e8,\n    \"contiguous_gradients\" : true,\n    \"round_robin_gradients\": true\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"bf16\": {\n      \"enabled\": \"auto\"\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": \"auto\",\n      \"betas\": \"auto\",\n      \"eps\": \"auto\",\n      \"weight_decay\": \"auto\"\n    }\n  }\n}\n"
  },
  {
    "path": "src/resources/config/deepspeed/stage-3-no_trainer.json",
    "content": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"steps_per_print\": 10,\n  \"wall_clock_breakdown\": false,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": false\n    },\n    \"offload_param\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": false\n    },\n    \"overlap_comm\": false,\n    \"stage3_max_live_parameters\" : 1e3,\n    \"stage3_max_reuse_distance\" : 1e3,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\" : \"auto\",\n    \"stage3_param_persistence_threshold\" : \"auto\",\n    \"sub_group_size\" : 1e3,\n    \"stage3_gather_16bit_weights_on_model_save\": true,\n    \"ignore_unused_parameters\": true\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 16,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"bf16\": {\n      \"enabled\": \"auto\"\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": \"auto\",\n      \"betas\": \"auto\",\n      \"eps\": \"auto\",\n      \"weight_decay\": \"auto\"\n    }\n  },\n  \"scheduler\": {\n    \"type\": \"OneCycle\",\n    \"params\": {\n      \"cycle_min_lr\": 0,\n      \"cycle_max_lr\": \"auto\",\n      \"cycle_first_step_size\": \"auto\",\n      \"cycle_momentum\": true,\n      \"cycle_min_mom\": 0.85,\n      \"cycle_max_mom\": 0.99\n    }\n  },\n  \"tensorboard\": {\n    \"enabled\": true,\n    \"output_path\": \"out/tensorboard/\"\n  }\n}\n"
  },
  {
    "path": "src/resources/config/deepspeed/stage-3.json",
    "content": "{\n  \"train_batch_size\": \"auto\",\n  \"train_micro_batch_size_per_gpu\": \"auto\",\n  \"gradient_accumulation_steps\": \"auto\",\n  \"gradient_clipping\": \"auto\",\n  \"steps_per_print\": 10,\n  \"wall_clock_breakdown\": false,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"offload_optimizer\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"offload_param\": {\n      \"device\": \"cpu\",\n      \"pin_memory\": true\n    },\n    \"overlap_comm\": false,\n    \"stage3_max_live_parameters\" : 1e3,\n    \"stage3_max_reuse_distance\" : 1e3,\n    \"reduce_bucket_size\": \"auto\",\n    \"stage3_prefetch_bucket_size\" : \"auto\",\n    \"stage3_param_persistence_threshold\" : \"auto\",\n    \"sub_group_size\" : 1e3,\n    \"stage3_gather_16bit_weights_on_model_save\": true,\n    \"ignore_unused_parameters\": true\n  },\n  \"fp16\": {\n    \"enabled\": \"auto\",\n    \"auto_cast\": false,\n    \"loss_scale\": 0,\n    \"initial_scale_power\": 32,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"bf16\": {\n      \"enabled\": \"auto\"\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": \"auto\",\n      \"betas\": \"auto\",\n      \"eps\": \"auto\",\n      \"weight_decay\": \"auto\"\n    }\n  }\n}\n"
  },
  {
    "path": "src/resources/config/log.conf",
    "content": "[loggers]\nkeys=root,main\n\n[handlers]\nkeys=consoleHanlder,fileHandler\n\n[formatters]\nkeys=fmt\n\n[logger_root]\nlevel=DEBUG\nhandlers=consoleHanlder,fileHandler\n\n[logger_main]\nlevel=DEBUG\nqualname=main\nhandlers=fileHandler\n\n[handler_consoleHanlder]\nclass=StreamHandler\nlevel=INFO\nformatter=fmt\nargs=(sys.stdout,)\n\n[handler_fileHandler]\nclass=logging.handlers.TimedRotatingFileHandler\nlevel=INFO\nformatter=fmt\nargs=('output.log', 'd', 1, 7)\n\n[formatter_fmt]\nformat=%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)s]  %(message)s\ndatefmt=%Y/%m/%d %H:%M:%S\n"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_glm-350M-chinese.yml",
    "content": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 50000\n  batch_size: 1\n  checkpoint_interval: 1000\n  eval_interval: 1000\n  pipeline: \"GLMPipeline\"\n  trainer: \"AcceleratePPOTrainer\"\n  tracker: \"tensorboard\"\n  logging_dir: \"tensorboard\"\n\nmodel:\n  model_path: \"/root/autodl-tmp/Data/models/glm-350M-chinese\"\n  num_layers_unfrozen: -1\n  model_arch_type: \"seq2seq\"\n\ntokenizer:\n  tokenizer_path: \"/root/autodl-tmp/Data/models/glm-350M-chinese\"\n  padding_side: \"right\"\n\noptimizer:\n  name: \"adamw\"\n  kwargs:\n    lr: 5.0e-6\n    betas: [0.9, 0.999]\n    eps: 1.0e-8\n    weight_decay: 0.01\n\nscheduler:\n  name: \"cosine_annealing\"\n  kwargs:\n    T_max: 100000\n    eta_min: 5.0e-6\n\nmethod:\n  name: \"ppoconfig\"\n  num_rollouts: 512\n  chunk_size: 1\n  ppo_epochs: 1\n  init_kl_coef: 0.1\n  target: 6\n  horizon: 10000\n  gamma: 1\n  lam: 0.95\n  cliprange: 0.2\n  cliprange_value: 0.2\n  vf_coef: 0.2\n  scale_reward: False\n  ref_mean: null\n  ref_std: null\n  cliprange_reward: 10\n  gen_kwargs:\n    max_new_tokens: 100\n    top_k: 50\n    top_p: 0.8\n    do_sample: True\n    temperature: 1.0"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_glm-small.yml",
    "content": "train:\n  seq_length: 256\n  epochs: 1\n  total_steps: 10000\n  batch_size: 2\n  checkpoint_interval: 100\n  eval_interval: 100\n  pipeline: \"GLMPipeline\"\n  trainer: \"AcceleratePPOTrainer\"\n  tracker: \"tensorboard\"\n  logging_dir: \"tensorboard\"\n\nmodel:\n  model_path: \"/root/autodl-tmp/Data/models/glm-350M-chinese\"\n  num_layers_unfrozen: -1\n  model_arch_type: \"seq2seq\"\n\ntokenizer:\n  tokenizer_path: \"/root/autodl-tmp/Data/models/glm-350M-chinese\"\n  padding_side: \"left\"\n\noptimizer:\n  name: \"adamw\"\n  kwargs:\n    lr: 5.0e-6\n    betas: [0.9, 0.999]\n    eps: 1.0e-8\n    weight_decay: 0.01\n\nscheduler:\n  name: \"cosine_annealing\"\n  kwargs:\n    T_max: 100000\n    eta_min: 5.0e-6\n\nmethod:\n  name: \"ppoconfig\"\n  num_rollouts: 2\n  chunk_size: 1\n  ppo_epochs: 1\n  init_kl_coef: 0.1\n  target: 6\n  horizon: 10000\n  gamma: 1\n  lam: 0.95\n  cliprange: 0.2\n  cliprange_value: 0.2\n  vf_coef: 0.2\n  scale_reward: False\n  ref_mean: null\n  ref_std: null\n  cliprange_reward: 10\n  gen_kwargs:\n    max_new_tokens: 5\n    top_k: 50\n    top_p: 0.8\n    do_sample: True\n    temperature: 1.0"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_pangu-2.6B.yml",
    "content": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 50000\n  batch_size: 4\n  checkpoint_interval: 1000\n  eval_interval: 1000\n  pipeline: \"PanguPipeline\"\n  trainer: \"AcceleratePPOTrainer\"\n  tracker: \"tensorboard\"\n  logging_dir: \"tensorboard\"\n\nmodel:\n  model_path: \"/root/autodl-tmp/Data/models/pangu-2.6B\"\n  num_layers_unfrozen: -1\n\ntokenizer:\n  tokenizer_path: \"/root/autodl-tmp/Data/models/pangu-2.6B\"\n  padding_side: \"left\"\n\noptimizer:\n  name: \"adamw\"\n  kwargs:\n    lr: 5.0e-6\n    betas: [0.9, 0.999]\n    eps: 1.0e-8\n    weight_decay: 0.01\n\nscheduler:\n  name: \"cosine_annealing\"\n  kwargs:\n    T_max: 100000\n    eta_min: 5.0e-6\n\nmethod:\n  name: \"ppoconfig\"\n  num_rollouts: 512\n  chunk_size: 4\n  ppo_epochs: 1\n  init_kl_coef: 0.1\n  target: 6\n  horizon: 10000\n  gamma: 1\n  lam: 0.95\n  cliprange: 0.2\n  cliprange_value: 0.2\n  vf_coef: 0.2\n  scale_reward: False\n  ref_mean: null\n  ref_std: null\n  cliprange_reward: 10\n  gen_kwargs:\n    max_new_tokens: 100\n    top_k: 50\n    top_p: 0.8\n    do_sample: True\n    temperature: 1.0"
  },
  {
    "path": "src/resources/config/ppo_model/ppo_config_pangu-350M.yml",
    "content": "train:\n  seq_length: 512\n  epochs: 1\n  total_steps: 100000\n  batch_size: 16\n  checkpoint_interval: 1000\n  eval_interval: 1000\n  pipeline: \"PanguPipeline\"\n  trainer: \"AcceleratePPOTrainer\"\n  tracker: \"tensorboard\"\n  logging_dir: \"tensorboard\"\n\nmodel:\n  model_path: \"/root/autodl-tmp/Data/models/pangu-350M\"\n  num_layers_unfrozen: -1\n\ntokenizer:\n  tokenizer_path: \"/root/autodl-tmp/Data/models/pangu-350M\"\n  padding_side: \"left\"\n\noptimizer:\n  name: \"adamw\"\n  kwargs:\n    lr: 5.0e-6\n    betas: [0.9, 0.999]\n    eps: 1.0e-8\n    weight_decay: 0.01\n\nscheduler:\n  name: \"cosine_annealing\"\n  kwargs:\n    T_max: 100000\n    eta_min: 5.0e-6\n\nmethod:\n  name: \"ppoconfig\"\n  num_rollouts: 512\n  chunk_size: 16\n  ppo_epochs: 1\n  init_kl_coef: 0.1\n  target: 6\n  horizon: 10000\n  gamma: 1\n  lam: 0.95\n  cliprange: 0.2\n  cliprange_value: 0.2\n  vf_coef: 0.2\n  scale_reward: False\n  ref_mean: null\n  ref_std: null\n  cliprange_reward: 10\n  gen_kwargs:\n    max_new_tokens: 100\n    top_k: 50\n    top_p: 0.8\n    do_sample: True\n    temperature: 1.0"
  },
  {
    "path": "src/resources/dict/jieba_custom_dict.txt",
    "content": "<s>\n</s>\n<eot>\n<unk>\n<sep>\n<pad>"
  },
  {
    "path": "src/resources/eval/ceval/subject_mapping.json",
    "content": "{\n  \"computer_network\": [\n    \"Computer Network\",\n    \"计算机网络\",\n    \"STEM\"\n  ],\n  \"operating_system\": [\n    \"Operating System\",\n    \"操作系统\",\n    \"STEM\"\n  ],\n  \"computer_architecture\": [\n    \"Computer Architecture\",\n    \"计算机组成\",\n    \"STEM\"\n  ],\n  \"college_programming\": [\n    \"College Programming\",\n    \"大学编程\",\n    \"STEM\"\n  ],\n  \"college_physics\": [\n    \"College Physics\",\n    \"大学物理\",\n    \"STEM\"\n  ],\n  \"college_chemistry\": [\n    \"College Chemistry\",\n    \"大学化学\",\n    \"STEM\"\n  ],\n  \"advanced_mathematics\": [\n    \"Advanced Mathematics\",\n    \"高等数学\",\n    \"STEM\"\n  ],\n  \"probability_and_statistics\": [\n    \"Probability and Statistics\",\n    \"概率统计\",\n    \"STEM\"\n  ],\n  \"discrete_mathematics\": [\n    \"Discrete Mathematics\",\n    \"离散数学\",\n    \"STEM\"\n  ],\n  \"electrical_engineer\": [\n    \"Electrical Engineer\",\n    \"注册电气工程师\",\n    \"STEM\"\n  ],\n  \"metrology_engineer\": [\n    \"Metrology Engineer\",\n    \"注册计量师\",\n    \"STEM\"\n  ],\n  \"high_school_mathematics\": [\n    \"High School Mathematics\",\n    \"高中数学\",\n    \"STEM\"\n  ],\n  \"high_school_physics\": [\n    \"High School Physics\",\n    \"高中物理\",\n    \"STEM\"\n  ],\n  \"high_school_chemistry\": [\n    \"High School Chemistry\",\n    \"高中化学\",\n    \"STEM\"\n  ],\n  \"high_school_biology\": [\n    \"High School Biology\",\n    \"高中生物\",\n    \"STEM\"\n  ],\n  \"middle_school_mathematics\": [\n    \"Middle School Mathematics\",\n    \"初中数学\",\n    \"STEM\"\n  ],\n  \"middle_school_biology\": [\n    \"Middle School Biology\",\n    \"初中生物\",\n    \"STEM\"\n  ],\n  \"middle_school_physics\": [\n    \"Middle School Physics\",\n    \"初中物理\",\n    \"STEM\"\n  ],\n  \"middle_school_chemistry\": [\n    \"Middle School Chemistry\",\n    \"初中化学\",\n    \"STEM\"\n  ],\n  \"veterinary_medicine\": [\n    \"Veterinary Medicine\",\n    \"兽医学\",\n    \"STEM\"\n  ],\n  \"college_economics\": [\n    \"College Economics\",\n    \"大学经济学\",\n    \"Social Science\"\n  ],\n  \"business_administration\": [\n    \"Business Administration\",\n    \"工商管理\",\n    \"Social Science\"\n  ],\n  \"marxism\": [\n    \"Marxism\",\n    \"马克思主义基本原理\",\n    \"Social Science\"\n  ],\n  \"mao_zedong_thought\": [\n    \"Mao Zedong Thought\",\n    \"毛泽东思想和中国特色社会主义理论体系概论\",\n    \"Social Science\"\n  ],\n  \"education_science\": [\n    \"Education Science\",\n    \"教育学\",\n    \"Social Science\"\n  ],\n  \"teacher_qualification\": [\n    \"Teacher Qualification\",\n    \"教师资格\",\n    \"Social Science\"\n  ],\n  \"high_school_politics\": [\n    \"High School Politics\",\n    \"高中政治\",\n    \"Social Science\"\n  ],\n  \"high_school_geography\": [\n    \"High School Geography\",\n    \"高中地理\",\n    \"Social Science\"\n  ],\n  \"middle_school_politics\": [\n    \"Middle School Politics\",\n    \"初中政治\",\n    \"Social Science\"\n  ],\n  \"middle_school_geography\": [\n    \"Middle School Geography\",\n    \"初中地理\",\n    \"Social Science\"\n  ],\n  \"modern_chinese_history\": [\n    \"Modern Chinese History\",\n    \"近代史纲要\",\n    \"Humanities\"\n  ],\n  \"ideological_and_moral_cultivation\": [\n    \"Ideological and Moral Cultivation\",\n    \"思想道德修养与法律基础\",\n    \"Humanities\"\n  ],\n  \"logic\": [\n    \"Logic\",\n    \"逻辑学\",\n    \"Humanities\"\n  ],\n  \"law\": [\n    \"Law\",\n    \"法学\",\n    \"Humanities\"\n  ],\n  \"chinese_language_and_literature\": [\n    \"Chinese Language and Literature\",\n    \"中国语言文学\",\n    \"Humanities\"\n  ],\n  
\"art_studies\": [\n    \"Art Studies\",\n    \"艺术学\",\n    \"Humanities\"\n  ],\n  \"professional_tour_guide\": [\n    \"Professional Tour Guide\",\n    \"导游资格\",\n    \"Humanities\"\n  ],\n  \"legal_professional\": [\n    \"Legal Professional\",\n    \"法律职业资格\",\n    \"Humanities\"\n  ],\n  \"high_school_chinese\": [\n    \"High School Chinese\",\n    \"高中语文\",\n    \"Humanities\"\n  ],\n  \"high_school_history\": [\n    \"High School History\",\n    \"高中历史\",\n    \"Humanities\"\n  ],\n  \"middle_school_history\": [\n    \"Middle School History\",\n    \"初中历史\",\n    \"Humanities\"\n  ],\n  \"civil_servant\": [\n    \"Civil Servant\",\n    \"公务员\",\n    \"Other\"\n  ],\n  \"sports_science\": [\n    \"Sports Science\",\n    \"体育学\",\n    \"Other\"\n  ],\n  \"plant_protection\": [\n    \"Plant Protection\",\n    \"植物保护\",\n    \"Other\"\n  ],\n  \"basic_medicine\": [\n    \"Basic Medicine\",\n    \"基础医学\",\n    \"Other\"\n  ],\n  \"clinical_medicine\": [\n    \"Clinical Medicine\",\n    \"临床医学\",\n    \"Other\"\n  ],\n  \"urban_and_rural_planner\": [\n    \"Urban and Rural Planner\",\n    \"注册城乡规划师\",\n    \"Other\"\n  ],\n  \"accountant\": [\n    \"Accountant\",\n    \"注册会计师\",\n    \"Other\"\n  ],\n  \"fire_engineer\": [\n    \"Fire Engineer\",\n    \"注册消防工程师\",\n    \"Other\"\n  ],\n  \"environmental_impact_assessment_engineer\": [\n    \"Environmental Impact Assessment Engineer\",\n    \"环境影响评价工程师\",\n    \"Other\"\n  ],\n  \"tax_accountant\": [\n    \"Tax Accountant\",\n    \"税务师\",\n    \"Other\"\n  ],\n  \"physician\": [\n    \"Physician\",\n    \"医师资格\",\n    \"Other\"\n  ]\n}"
  },
  {
    "path": "src/resources/eval/ceval/submission_example.json",
    "content": "{\n    \"chinese_language_and_literature\": {\n        \"0\": \"A\",\n        \"1\": \"B\",\n        \"2\": \"B\"\n    },\n    \"clinical_medicine\": {\n        \"0\": \"C\",\n        \"1\": \"D\",\n        \"2\": \"C\"\n    },\n    \"sports_science\": {\n        \"0\": \"D\",\n        \"1\": \"C\",\n        \"2\": \"B\"\n    },\n    \"civil_servant\": {\n        \"0\": \"A\",\n        \"1\": \"B\",\n        \"2\": \"D\"\n    },\n    \"veterinary_medicine\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"C\"\n    },\n    \"middle_school_chemistry\": {\n        \"0\": \"A\",\n        \"1\": \"A\",\n        \"2\": \"C\"\n    },\n    \"middle_school_history\": {\n        \"0\": \"C\",\n        \"1\": \"D\",\n        \"2\": \"A\"\n    },\n    \"middle_school_geography\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"C\"\n    },\n    \"middle_school_politics\": {\n        \"0\": \"B\",\n        \"1\": \"B\",\n        \"2\": \"D\"\n    },\n    \"middle_school_mathematics\": {\n        \"0\": \"B\",\n        \"1\": \"D\",\n        \"2\": \"C\"\n    },\n    \"middle_school_physics\": {\n        \"0\": \"A\",\n        \"1\": \"A\",\n        \"2\": \"D\"\n    },\n    \"middle_school_biology\": {\n        \"0\": \"C\",\n        \"1\": \"A\",\n        \"2\": \"A\"\n    },\n    \"physician\": {\n        \"0\": \"B\",\n        \"1\": \"D\",\n        \"2\": \"D\"\n    },\n    \"basic_medicine\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"D\"\n    },\n    \"modern_chinese_history\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"D\"\n    },\n    \"college_chemistry\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"C\"\n    },\n    \"college_physics\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"A\"\n    },\n    \"college_economics\": {\n        \"0\": \"A\",\n        \"1\": \"A\",\n        \"2\": \"B\"\n    },\n    \"college_programming\": {\n        \"0\": \"D\",\n        \"1\": \"D\",\n        \"2\": \"D\"\n    },\n    \"professional_tour_guide\": {\n        \"0\": \"B\",\n        \"1\": \"B\",\n        \"2\": \"A\"\n    },\n    \"business_administration\": {\n        \"0\": \"C\",\n        \"1\": \"B\",\n        \"2\": \"A\"\n    },\n    \"ideological_and_moral_cultivation\": {\n        \"0\": \"A\",\n        \"1\": \"C\",\n        \"2\": \"A\"\n    },\n    \"operating_system\": {\n        \"0\": \"C\",\n        \"1\": \"A\",\n        \"2\": \"D\"\n    },\n    \"teacher_qualification\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"B\"\n    },\n    \"education_science\": {\n        \"0\": \"D\",\n        \"1\": \"D\",\n        \"2\": \"B\"\n    },\n    \"plant_protection\": {\n        \"0\": \"B\",\n        \"1\": \"D\",\n        \"2\": \"B\"\n    },\n    \"probability_and_statistics\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"A\"\n    },\n    \"mao_zedong_thought\": {\n        \"0\": \"B\",\n        \"1\": \"B\",\n        \"2\": \"B\"\n    },\n    \"law\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"B\"\n    },\n    \"legal_professional\": {\n        \"0\": \"D\",\n        \"1\": \"C\",\n        \"2\": \"D\"\n    },\n    \"accountant\": {\n        \"0\": \"A\",\n        \"1\": \"B\",\n        \"2\": \"D\"\n    },\n    \"urban_and_rural_planner\": {\n        \"0\": \"B\",\n        \"1\": \"C\",\n        \"2\": \"A\"\n    },\n    \"fire_engineer\": {\n        \"0\": \"A\",\n        \"1\": \"C\",\n      
  \"2\": \"C\"\n    },\n    \"electrical_engineer\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"A\"\n    },\n    \"metrology_engineer\": {\n        \"0\": \"B\",\n        \"1\": \"B\",\n        \"2\": \"C\"\n    },\n    \"environmental_impact_assessment_engineer\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"A\"\n    },\n    \"discrete_mathematics\": {\n        \"0\": \"C\",\n        \"1\": \"A\",\n        \"2\": \"C\"\n    },\n    \"tax_accountant\": {\n        \"0\": \"A\",\n        \"1\": \"C\",\n        \"2\": \"C\"\n    },\n    \"art_studies\": {\n        \"0\": \"A\",\n        \"1\": \"B\",\n        \"2\": \"A\"\n    },\n    \"computer_architecture\": {\n        \"0\": \"A\",\n        \"1\": \"C\",\n        \"2\": \"A\"\n    },\n    \"computer_network\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"C\"\n    },\n    \"logic\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"B\"\n    },\n    \"marxism\": {\n        \"0\": \"B\",\n        \"1\": \"B\",\n        \"2\": \"B\"\n    },\n    \"high_school_chemistry\": {\n        \"0\": \"B\",\n        \"1\": \"C\",\n        \"2\": \"D\"\n    },\n    \"high_school_history\": {\n        \"0\": \"A\",\n        \"1\": \"D\",\n        \"2\": \"A\"\n    },\n    \"high_school_geography\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"D\"\n    },\n    \"high_school_politics\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"B\"\n    },\n    \"high_school_mathematics\": {\n        \"0\": \"D\",\n        \"1\": \"D\",\n        \"2\": \"A\"\n    },\n    \"high_school_physics\": {\n        \"0\": \"C\",\n        \"1\": \"C\",\n        \"2\": \"C\"\n    },\n    \"high_school_biology\": {\n        \"0\": \"B\",\n        \"1\": \"A\",\n        \"2\": \"B\"\n    },\n    \"high_school_chinese\": {\n        \"0\": \"D\",\n        \"1\": \"B\",\n        \"2\": \"A\"\n    },\n    \"advanced_mathematics\": {\n        \"0\": \"C\",\n        \"1\": \"D\",\n        \"2\": \"A\"\n    }\n}"
  },
  {
    "path": "src/resources/eval/mmlu/subject_mapping.json",
    "content": "{\n  \"moral_scenarios\": [\n    \"Moral Scenarios\",\n    \"Humanities\"\n  ],\n  \"us_foreign_policy\": [\n    \"US Foreign Policy\",\n    \"Social Sciences\"\n  ],\n  \"public_relations\": [\n    \"Public Relations\",\n    \"Social Sciences\"\n  ],\n  \"global_facts\": [\n    \"Global Facts\",\n    \"Other\"\n  ],\n  \"electrical_engineering\": [\n    \"Electrical Engineering\",\n    \"STEM\"\n  ],\n  \"astronomy\": [\n    \"Astronomy\",\n    \"STEM\"\n  ],\n  \"business_ethics\": [\n    \"Business Ethics\",\n    \"Other\"\n  ],\n  \"jurisprudence\": [\n    \"Jurisprudence\",\n    \"Humanities\"\n  ],\n  \"high_school_chemistry\": [\n    \"High School Chemistry\",\n    \"STEM\"\n  ],\n  \"college_physics\": [\n    \"College Physics\",\n    \"STEM\"\n  ],\n  \"professional_psychology\": [\n    \"Professional Psychology\",\n    \"Social Sciences\"\n  ],\n  \"marketing\": [\n    \"Marketing\",\n    \"Other\"\n  ],\n  \"management\": [\n    \"Management\",\n    \"Other\"\n  ],\n  \"virology\": [\n    \"Virology\",\n    \"Other\"\n  ],\n  \"international_law\": [\n    \"International Law\",\n    \"Humanities\"\n  ],\n  \"high_school_macroeconomics\": [\n    \"High School Macroeconomics\",\n    \"Social Sciences\"\n  ],\n  \"prehistory\": [\n    \"Prehistory\",\n    \"Humanities\"\n  ],\n  \"abstract_algebra\": [\n    \"Abstract Algebra\",\n    \"STEM\"\n  ],\n  \"high_school_physics\": [\n    \"High School Physics\",\n    \"STEM\"\n  ],\n  \"formal_logic\": [\n    \"Formal Logic\",\n    \"Humanities\"\n  ],\n  \"college_medicine\": [\n    \"College Medicine\",\n    \"Other\"\n  ],\n  \"high_school_us_history\": [\n    \"High School US History\",\n    \"Humanities\"\n  ],\n  \"moral_disputes\": [\n    \"Moral Disputes\",\n    \"Humanities\"\n  ],\n  \"high_school_european_history\": [\n    \"High School European History\",\n    \"Humanities\"\n  ],\n  \"clinical_knowledge\": [\n    \"Clinical Knowledge\",\n    \"Other\"\n  ],\n  \"world_religions\": [\n    \"World Religions\",\n    \"Humanities\"\n  ],\n  \"high_school_microeconomics\": [\n    \"High School Microeconomics\",\n    \"Social Sciences\"\n  ],\n  \"professional_law\": [\n    \"Professional Law\",\n    \"Humanities\"\n  ],\n  \"human_aging\": [\n    \"Human Aging\",\n    \"Other\"\n  ],\n  \"medical_genetics\": [\n    \"Medical Genetics\",\n    \"Other\"\n  ],\n  \"high_school_geography\": [\n    \"High School Geography\",\n    \"Social Sciences\"\n  ],\n  \"high_school_government_and_politics\": [\n    \"High School Government and Politics\",\n    \"Social Sciences\"\n  ],\n  \"anatomy\": [\n    \"Anatomy\",\n    \"STEM\"\n  ],\n  \"sociology\": [\n    \"Sociology\",\n    \"Social Sciences\"\n  ],\n  \"logical_fallacies\": [\n    \"Logical Fallacies\",\n    \"Humanities\"\n  ],\n  \"high_school_computer_science\": [\n    \"High School Computer Science\",\n    \"STEM\"\n  ],\n  \"miscellaneous\": [\n    \"Miscellaneous\",\n    \"Other\"\n  ],\n  \"high_school_world_history\": [\n    \"High School World History\",\n    \"Humanities\"\n  ],\n  \"professional_medicine\": [\n    \"Professional Medicine\",\n    \"Other\"\n  ],\n  \"high_school_biology\": [\n    \"High School Biology\",\n    \"STEM\"\n  ],\n  \"high_school_statistics\": [\n    \"High School Statistics\",\n    \"STEM\"\n  ],\n  \"college_chemistry\": [\n    \"College Chemistry\",\n    \"STEM\"\n  ],\n  \"nutrition\": [\n    \"Nutrition\",\n    \"Other\"\n  ],\n  \"econometrics\": [\n    \"Econometrics\",\n    \"Social Sciences\"\n  ],\n  
\"human_sexuality\": [\n    \"Human Sexuality\",\n    \"Social Sciences\"\n  ],\n  \"security_studies\": [\n    \"Security Studies\",\n    \"Social Sciences\"\n  ],\n  \"philosophy\": [\n    \"Philosophy\",\n    \"Humanities\"\n  ],\n  \"elementary_mathematics\": [\n    \"Elementary Mathematics\",\n    \"STEM\"\n  ],\n  \"college_biology\": [\n    \"College Biology\",\n    \"STEM\"\n  ],\n  \"college_computer_science\": [\n    \"College Computer Science\",\n    \"STEM\"\n  ],\n  \"machine_learning\": [\n    \"Machine Learning\",\n    \"STEM\"\n  ],\n  \"professional_accounting\": [\n    \"Professional Accounting\",\n    \"Other\"\n  ],\n  \"college_mathematics\": [\n    \"College Mathematics\",\n    \"STEM\"\n  ],\n  \"high_school_mathematics\": [\n    \"High School Mathematics\",\n    \"STEM\"\n  ],\n  \"high_school_psychology\": [\n    \"High School Psychology\",\n    \"Social Sciences\"\n  ],\n  \"conceptual_physics\": [\n    \"Conceptual Physics\",\n    \"STEM\"\n  ],\n  \"computer_security\": [\n    \"Computer Security\",\n    \"STEM\"\n  ]\n}"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/config.json",
    "content": "{\n  \"_from_model_config\": true,\n  \"architectures\": [\n    \"BaichuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_baichuan.BaichuanConfig\",\n    \"AutoModelForCausalLM\": \"modeling_baichuan.BaichuanForCausalLM\"\n  },\n  \"bos_token_id\": 1,\n  \"eos_token_id\": 2,\n  \"gradient_checkpointing\": false,\n  \"hidden_act\": \"silu\",\n  \"hidden_size\": 5120,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 13696,\n  \"model_max_length\": 4096,\n  \"model_type\": \"baichuan\",\n  \"num_attention_heads\": 40,\n  \"num_hidden_layers\": 40,\n  \"pad_token_id\": 0,\n  \"rms_norm_eps\": 1e-06,\n  \"tie_word_embeddings\": false,\n  \"torch_dtype\": \"bfloat16\",\n  \"transformers_version\": \"4.29.2\",\n  \"use_cache\": true,\n  \"vocab_size\": 64000\n}\n"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/configuration_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nfrom transformers.configuration_utils import PretrainedConfig\n\nclass BaichuanConfig(PretrainedConfig):\n    model_type = \"baichuan\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=64000,\n        hidden_size=5120,\n        intermediate_size=13696,\n        num_hidden_layers=40,\n        num_attention_heads=40,\n        hidden_act=\"silu\",\n        model_max_length=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        gradient_checkpointing=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.model_max_length = model_max_length\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.gradient_checkpointing = gradient_checkpointing,\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n"
  },
  {
    "path": "src/resources/models/baichuan-13B-base/modeling_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import PreTrainedModel\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom transformers.utils import logging\nfrom transformers.generation.utils import GenerationConfig\n\nfrom .configuration_baichuan import BaichuanConfig\n\nlogger = logging.get_logger(__name__)\n\n\ndef _get_interleave(n):\n    def _get_interleave_power_of_2(n):\n        start = (2 ** (-2 ** -(math.log2(n) - 3)))\n        ratio = start\n        return [start * ratio ** i for i in range(n)]\n\n    if math.log2(n).is_integer():\n        return _get_interleave_power_of_2(n)\n    else:\n        closest_power_of_2 = 2 ** math.floor(math.log2(n))\n        return _get_interleave_power_of_2(closest_power_of_2) + \\\n               _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]\n\n\ndef _fill_with_neg_inf(t):\n    \"\"\"FP16-compatible function that fills a tensor with -inf.\"\"\"\n    return t.float().fill_(float(\"-inf\")).type_as(t)\n\n\ndef _gen_alibi_mask(n_head, max_pos, alibi_mask=None):\n    slopes = torch.Tensor(_get_interleave(n_head))\n    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(\n        n_head, -1, -1)\n    alibi = alibi.view(n_head, 1, max_pos)\n    if alibi_mask is None:\n        alibi_mask = torch.triu(\n            _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1\n        )\n    alibi_mask = alibi_mask.unsqueeze(0) + alibi\n    return alibi_mask\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, hidden_size, epsilon=1e-6):\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size))\n        self.epsilon = epsilon\n\n    def forward(self, hidden_states):\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)\n\n        # convert into half-precision\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\nclass MLP(torch.nn.Module):\n    def __init__(\n            self,\n            hidden_size: int,\n            intermediate_size: int,\n            hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass BaichuanAttention(torch.nn.Module):\n\n    def __init__(self, config: BaichuanConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.max_position_embeddings = config.model_max_length\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size 
{self.hidden_size} is not divisible by num_heads {self.num_heads}\"\n            )\n        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)\n        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        proj = self.W_pack(hidden_states)\n        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)\n        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:\n            if attn_weights.size(-2) == 1:\n                attention_mask = attention_mask[:, -1:, :]\n            attn_weights = attn_weights + attention_mask.unsqueeze(0)\n            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass BaichuanLayer(torch.nn.Module):\n    def __init__(self, config: BaichuanConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = BaichuanAttention(config=config)\n        self.mlp = MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: Optional[bool] = False,\n            use_cache: Optional[bool] = False,\n    ) -> 
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass BaichuanPreTrainedModel(PreTrainedModel):\n    config_class = BaichuanConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"BaichuanLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, torch.nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, BaichuanModel):\n            module.gradient_checkpointing = value\n\n\nclass BaichuanModel(BaichuanPreTrainedModel):\n    def __init__(self, config: BaichuanConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.n_head = config.num_attention_heads\n        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n\n        self.gradient_checkpointing = config.gradient_checkpointing\n        self.post_init()\n        self.max_cache_pos = config.model_max_length\n        self.first_run = True    \n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n        \n    def set_input_embeddings(self, value):\n        self.embed_tokens = value  \n        \n    def get_alibi_mask(self, tensor, seq_length_with_past, attention_mask):\n        if self.first_run:\n            self.first_run = False\n            self.register_buffer(\n                \"future_mask\",\n                _gen_alibi_mask(self.n_head, self.max_cache_pos, attention_mask).to(tensor),\n                persistent=False\n            )\n        if seq_length_with_past > self.max_cache_pos:\n            self.max_cache_pos = seq_length_with_past\n            self.register_buffer(\n                \"future_mask\",\n                _gen_alibi_mask(self.n_head, self.max_cache_pos, attention_mask).to(tensor),\n                persistent=False\n            )\n        mask = self.future_mask[:self.n_head, :seq_length_with_past, 
:seq_length_with_past] \n        return mask\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = False,\n            output_attentions: Optional[bool] = False,\n            output_hidden_states: Optional[bool] = False,\n            return_dict: Optional[bool] = True,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot provide both input_ids and inputs_embeds simultaneously\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You need to provide input_ids or inputs_embeds\")\n\n        seq_length_with_past = seq_length\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        # embed positions\n        attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past, attention_mask)\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states 
from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n    \n\nclass BaichuanForCausalLM(BaichuanPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = BaichuanModel(config)\n        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            labels: Optional[torch.LongTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = False,\n            output_hidden_states: Optional[bool] = False,\n            return_dict: Optional[bool] = True,\n            **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):  \n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            
model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {   \n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n            }   \n        )   \n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        return tuple(\n            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)\n            for layer_past in past_key_values\n        )\n\n    def quantize(self, bits: int):\n        try:\n            from .quantizer import QLinear\n        except ImportError:\n            raise ImportError(\n                f\"Needs QLinear to run quantize.\"\n            )\n        \n        for layer in self.model.layers:\n            layer.self_attn.W_pack = QLinear(\n                bits=bits,\n                weight=layer.self_attn.W_pack.weight,\n                bias = None,\n            )\n            layer.self_attn.o_proj = QLinear(\n                bits=bits,\n                weight=layer.self_attn.o_proj.weight,\n                bias = None,\n            )\n            layer.mlp.gate_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.gate_proj.weight,\n                bias = None,\n            )\n            layer.mlp.down_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.down_proj.weight,\n                bias = None,\n            )\n            layer.mlp.up_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.up_proj.weight,\n                bias = None,\n            )\n        return self \n\n    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):\n        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens\n        max_input_tokens = self.config.model_max_length - max_new_tokens\n        max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)\n        total_input, round_input = [], []\n        for i, message in enumerate(messages[::-1]):\n            content_tokens = tokenizer.encode(message['content'])\n            if message['role'] == 'user':\n                round_input = [self.generation_config.user_token_id] + content_tokens + round_input\n                if total_input and len(total_input) + len(round_input) > max_input_tokens:\n                    break\n                else:\n                    total_input = round_input + total_input\n                    if len(total_input) >= max_input_tokens:\n                        break\n                    else:\n                        round_input = []\n            elif message['role'] == 'assistant':\n                round_input = [\n                    self.generation_config.assistant_token_id\n                ] + content_tokens + [\n                    self.generation_config.eos_token_id\n                ] + round_input\n            else:\n                raise ValueError(f\"message role not supported yet: {message['role']}\")\n        total_input = total_input[-max_input_tokens:]  # truncate left\n        total_input.append(self.generation_config.assistant_token_id)\n        total_input = torch.LongTensor([total_input]).to(self.device)\n        return total_input\n\n    @torch.no_grad()\n    def chat(self, tokenizer, messages: List[dict], stream=False,\n             generation_config: Optional[GenerationConfig]=None):\n        generation_config = generation_config or self.generation_config\n        input_ids = 
self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)\n        if stream:\n            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig\n            self.__class__.generate = NewGenerationMixin.generate\n            self.__class__.sample_stream = NewGenerationMixin.sample_stream\n            stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)\n\n            def stream_generator():\n                outputs = []\n                for token in self.generate(input_ids, generation_config=stream_config):\n                    outputs.append(token.item())\n                    yield tokenizer.decode(outputs, skip_special_tokens=True)\n\n            return stream_generator()\n        else:\n            self.__class__.generate = PreTrainedModel.generate  # disable stream\n            outputs = self.generate(input_ids, generation_config=generation_config)\n            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)\n            return response\n"
  },
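The `prepare_inputs_for_generation` above implements the standard KV-cache contract: once `past_key_values` is populated, only the newest token id is fed forward, because keys/values for all earlier positions already live in each `BaichuanAttention` cache. A minimal greedy-decoding sketch under that contract — the checkpoint path is a placeholder (weights are downloaded separately, see the model table above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "src/resources/models/baichuan-13B-base"  # placeholder local checkpoint dir
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

input_ids = tokenizer("登鹳雀楼->王之涣\n夜雨寄北->", return_tensors="pt").input_ids
past = None
with torch.no_grad():
    for _ in range(32):
        # First step: feed the full prompt; afterwards only the newest token,
        # since K/V for earlier positions come from past_key_values.
        step_ids = input_ids if past is None else input_ids[:, -1:]
        out = model(input_ids=step_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```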
  {
    "path": "src/resources/models/baichuan-13B-base/tokenization_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport sentencepiece as spm\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"tokenizer.model\"}\n\nPRETRAINED_VOCAB_FILES_MAP = {\n    \"vocab_file\": {},\n    \"tokenizer_file\": {},\n}\nPRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}\n\n\nclass BaichuanTokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=None,\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs\n        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token\n        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token\n        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token\n        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            add_bos_token=add_bos_token,\n            add_eos_token=add_eos_token,\n            sp_model_kwargs=self.sp_model_kwargs,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n        self.vocab_file = vocab_file\n        self.add_bos_token = add_bos_token\n        self.add_eos_token = add_eos_token\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(vocab_file)\n\n    def __getstate__(self):\n        state = self.__dict__.copy()\n        state[\"sp_model\"] = None\n        return state\n\n    def __setstate__(self, d):\n        self.__dict__ = d\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(self.vocab_file)\n\n    @property\n    def vocab_size(self):\n        \"\"\"Returns vocab size\"\"\"\n        return self.sp_model.get_piece_size()\n\n    def get_vocab(self):\n        \"\"\"Returns vocab as a dict\"\"\"\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text):\n        \"\"\"Returns a tokenized string.\"\"\"\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return 
self.sp_model.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        token = self.sp_model.IdToPiece(index)\n        return token\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for i, token in enumerate(tokens):\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special and i != 0:\n                    out_string += \" \"\n                out_string += self.sp_model.decode(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.sp_model.decode(current_sub_tokens)\n        return out_string\n\n    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = bos_token_id + token_ids_0 + eos_token_id\n\n        if token_ids_1 is not None:\n            output = output + bos_token_id + token_ids_1 + eos_token_id\n\n        return output\n\n    def get_special_tokens_mask(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n            )\n\n        bos_token_id = [1] if self.add_bos_token else []\n        eos_token_id = [1] if self.add_eos_token else []\n\n        if token_ids_1 is None:\n            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id\n        return (\n            bos_token_id\n            + ([0] * len(token_ids_0))\n            + eos_token_id\n            + bos_token_id\n            + ([0] * len(token_ids_1))\n            + eos_token_id\n        )\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT\n        sequence pair mask has the following format:\n\n        ```\n        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n        | first sequence    | second sequence |\n        ```\n\n        if token_ids_1 is None, only returns the first portion of the mask (0s).\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of ids.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).\n        \"\"\"\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)\n\n        if token_ids_1 is not None:\n            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)\n\n        return output\n\n"
  },
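A usage sketch for the tokenizer class above. It wraps a SentencePiece model; `tokenize` delegates to `sp_model.encode` and `convert_tokens_to_string` undoes it. The `tokenizer.model` path is a placeholder (the vocab file ships with the checkpoint, not with this listing), and the import assumes you run from inside the model directory:

```python
from tokenization_baichuan import BaichuanTokenizer

tok = BaichuanTokenizer(vocab_file="tokenizer.model")  # placeholder path

pieces = tok.tokenize("今天天气不错")          # SentencePiece pieces via sp_model.encode
ids = tok.convert_tokens_to_ids(pieces)        # sp_model.piece_to_id per piece
text = tok.convert_tokens_to_string(pieces)    # round-trips back to the input
print(pieces, ids, text)
```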
  {
    "path": "src/resources/models/baichuan-13B-base/tokenizer_config.json",
    "content": "{\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_baichuan.BaichuanTokenizer\",\n      null\n    ]\n  },\n  \"bos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"clean_up_tokenization_spaces\": false,\n  \"eos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"</s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"model_max_length\": 4096,\n  \"pad_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<unk>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"sp_model_kwargs\": {},\n  \"tokenizer_class\": \"BaichuanTokenizer\",\n  \"unk_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<unk>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  }\n}\n"
  },
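The `auto_map` block in this `tokenizer_config.json` is what lets `AutoTokenizer` resolve the custom class without Baichuan being registered in `transformers`; loading therefore requires `trust_remote_code=True` so the local `tokenization_baichuan.py` gets imported. A sketch (the local path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "src/resources/models/baichuan-13B-base",  # placeholder local path
    trust_remote_code=True,  # needed for auto_map -> tokenization_baichuan.BaichuanTokenizer
)
print(type(tokenizer).__name__)      # BaichuanTokenizer
print(tokenizer.model_max_length)    # 4096, from tokenizer_config.json
```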
  {
    "path": "src/resources/models/baichuan-13B-chat/config.json",
    "content": "{\n  \"_from_model_config\": true,\n  \"architectures\": [\n    \"BaichuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_baichuan.BaichuanConfig\",\n    \"AutoModelForCausalLM\": \"modeling_baichuan.BaichuanForCausalLM\"\n  },\n  \"bos_token_id\": 1,\n  \"eos_token_id\": 2,\n  \"hidden_act\": \"silu\",\n  \"hidden_size\": 5120,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 13696,\n  \"model_max_length\": 4096,\n  \"model_type\": \"baichuan\",\n  \"num_attention_heads\": 40,\n  \"num_hidden_layers\": 40,\n  \"pad_token_id\": 0,\n  \"rms_norm_eps\": 1e-06,\n  \"tie_word_embeddings\": false,\n  \"torch_dtype\": \"bfloat16\",\n  \"transformers_version\": \"4.29.2\",\n  \"use_cache\": true,\n  \"vocab_size\": 64000\n}\n"
  },
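Correspondingly, `auto_map` in this `config.json` routes `AutoConfig` and `AutoModelForCausalLM` to the custom classes defined in the files below. A minimal loading sketch — the path is a placeholder and the dtype mirrors the `"torch_dtype"` field of the config:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

path = "src/resources/models/baichuan-13B-chat"  # placeholder local path
config = AutoConfig.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype" in config.json
    trust_remote_code=True,      # needed for auto_map -> modeling_baichuan.BaichuanForCausalLM
)
```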
  {
    "path": "src/resources/models/baichuan-13B-chat/configuration_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nfrom transformers.configuration_utils import PretrainedConfig\n\nclass BaichuanConfig(PretrainedConfig):\n    model_type = \"baichuan\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=64000,\n        hidden_size=5120,\n        intermediate_size=13696,\n        num_hidden_layers=40,\n        num_attention_heads=40,\n        hidden_act=\"silu\",\n        model_max_length=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        gradient_checkpointing=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.model_max_length = model_max_length\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.gradient_checkpointing = gradient_checkpointing,\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n"
  },
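A quick sanity check for the config class above; all values are the defaults already present in the file, and the import assumes you run from inside the model directory:

```python
from configuration_baichuan import BaichuanConfig

cfg = BaichuanConfig()  # defaults match baichuan-13B-chat/config.json
assert cfg.hidden_size % cfg.num_attention_heads == 0  # head_dim = 5120 / 40 = 128
assert isinstance(cfg.gradient_checkpointing, bool)    # plain bool, not a tuple
print(cfg.model_max_length)  # 4096
```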
  {
    "path": "src/resources/models/baichuan-13B-chat/modeling_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import PreTrainedModel\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom transformers.utils import logging\nfrom transformers.generation.utils import GenerationConfig\n\nfrom .configuration_baichuan import BaichuanConfig\n\nlogger = logging.get_logger(__name__)\n\n\ndef _get_interleave(n):\n    def _get_interleave_power_of_2(n):\n        start = (2 ** (-2 ** -(math.log2(n) - 3)))\n        ratio = start\n        return [start * ratio ** i for i in range(n)]\n\n    if math.log2(n).is_integer():\n        return _get_interleave_power_of_2(n)\n    else:\n        closest_power_of_2 = 2 ** math.floor(math.log2(n))\n        return _get_interleave_power_of_2(closest_power_of_2) + \\\n               _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]\n\n\ndef _fill_with_neg_inf(t):\n    \"\"\"FP16-compatible function that fills a tensor with -inf.\"\"\"\n    return t.float().fill_(float(\"-inf\")).type_as(t)\n\n\ndef _gen_alibi_mask(n_head, max_pos, alibi_mask=None):\n    slopes = torch.Tensor(_get_interleave(n_head))\n    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(\n        n_head, -1, -1)\n    alibi = alibi.view(n_head, 1, max_pos)\n    if alibi_mask is None:\n        alibi_mask = torch.triu(\n            _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1\n        )\n    alibi_mask = alibi_mask.unsqueeze(0) + alibi\n    return alibi_mask\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, hidden_size, epsilon=1e-6):\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size))\n        self.epsilon = epsilon\n\n    def forward(self, hidden_states):\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)\n\n        # convert into half-precision\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\nclass MLP(torch.nn.Module):\n    def __init__(\n            self,\n            hidden_size: int,\n            intermediate_size: int,\n            hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass BaichuanAttention(torch.nn.Module):\n\n    def __init__(self, config: BaichuanConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.max_position_embeddings = config.model_max_length\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size 
{self.hidden_size} is not divisible by num_heads {self.num_heads}\"\n            )\n        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)\n        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        proj = self.W_pack(hidden_states)\n        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)\n        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:\n            if attn_weights.size(-2) == 1:\n                attention_mask = attention_mask[:, -1:, :]\n            attn_weights = attn_weights + attention_mask.unsqueeze(0)\n            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass BaichuanLayer(torch.nn.Module):\n    def __init__(self, config: BaichuanConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = BaichuanAttention(config=config)\n        self.mlp = MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: Optional[bool] = False,\n            use_cache: Optional[bool] = False,\n    ) -> 
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass BaichuanPreTrainedModel(PreTrainedModel):\n    config_class = BaichuanConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"BaichuanLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, torch.nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, BaichuanModel):\n            module.gradient_checkpointing = value\n\n\nclass BaichuanModel(BaichuanPreTrainedModel):\n    def __init__(self, config: BaichuanConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.n_head = config.num_attention_heads\n        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)\n\n        self.gradient_checkpointing = config.gradient_checkpointing\n        self.post_init()\n        self.max_cache_pos = config.model_max_length\n        self.first_run = True    \n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n        \n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n        \n    def get_alibi_mask(self, tensor, seq_length_with_past, attention_mask):\n        if self.first_run:\n            self.first_run = False\n            self.register_buffer(\n                \"future_mask\",\n                _gen_alibi_mask(self.n_head, self.max_cache_pos, attention_mask).to(tensor),\n                persistent=False\n            )\n        if seq_length_with_past > self.max_cache_pos:\n            self.max_cache_pos = seq_length_with_past\n            self.register_buffer(\n                \"future_mask\",\n                _gen_alibi_mask(self.n_head, self.max_cache_pos, attention_mask).to(tensor),\n                persistent=False\n            )\n        mask = self.future_mask[:self.n_head, :seq_length_with_past, 
:seq_length_with_past] \n        return mask\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = False,\n            output_attentions: Optional[bool] = False,\n            output_hidden_states: Optional[bool] = False,\n            return_dict: Optional[bool] = True,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot provide both input_ids and inputs_embeds simultaneously\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You need to provide input_ids or inputs_embeds\")\n\n        seq_length_with_past = seq_length\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        # embed positions\n        attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past, attention_mask)\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states 
from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n    \n\nclass BaichuanForCausalLM(BaichuanPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = BaichuanModel(config)\n        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            labels: Optional[torch.LongTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = False,\n            output_hidden_states: Optional[bool] = False,\n            return_dict: Optional[bool] = True,\n            **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )   \n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):  \n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            
model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {   \n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n            }   \n        )   \n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        return tuple(\n            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)\n            for layer_past in past_key_values\n        )\n\n\n    def quantize(self, bits: int):\n        try:\n            from .quantizer import QLinear\n        except ImportError:\n            raise ImportError(\n                f\"Needs QLinear to run quantize.\"\n            )\n        \n        for layer in self.model.layers:\n            layer.self_attn.W_pack = QLinear(\n                bits=bits,\n                weight=layer.self_attn.W_pack.weight,\n                bias = None,\n            )\n            layer.self_attn.o_proj = QLinear(\n                bits=bits,\n                weight=layer.self_attn.o_proj.weight,\n                bias = None,\n            )\n            layer.mlp.gate_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.gate_proj.weight,\n                bias = None,\n            )\n            layer.mlp.down_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.down_proj.weight,\n                bias = None,\n            )\n            layer.mlp.up_proj = QLinear(\n                bits=bits,\n                weight=layer.mlp.up_proj.weight,\n                bias = None,\n            )\n        return self \n\n    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):\n        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens\n        max_input_tokens = self.config.model_max_length - max_new_tokens\n        max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)\n        total_input, round_input = [], []\n        for i, message in enumerate(messages[::-1]):\n            content_tokens = tokenizer.encode(message['content'])\n            if message['role'] == 'user':\n                round_input = [self.generation_config.user_token_id] + content_tokens + round_input\n                if total_input and len(total_input) + len(round_input) > max_input_tokens:\n                    break\n                else:\n                    total_input = round_input + total_input\n                    if len(total_input) >= max_input_tokens:\n                        break\n                    else:\n                        round_input = []\n            elif message['role'] == 'assistant':\n                round_input = [\n                    self.generation_config.assistant_token_id\n                ] + content_tokens + [\n                    self.generation_config.eos_token_id\n                ] + round_input\n            else:\n                raise ValueError(f\"message role not supported yet: {message['role']}\")\n        total_input = total_input[-max_input_tokens:]  # truncate left\n        total_input.append(self.generation_config.assistant_token_id)\n        total_input = torch.LongTensor([total_input]).to(self.device)\n        return total_input\n\n    @torch.no_grad()\n    def chat(self, tokenizer, messages: List[dict], stream=False,\n             generation_config: Optional[GenerationConfig]=None):\n        generation_config = generation_config or self.generation_config\n        input_ids = 
self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)\n        if stream:\n            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig\n            self.__class__.generate = NewGenerationMixin.generate\n            self.__class__.sample_stream = NewGenerationMixin.sample_stream\n            stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)\n\n            def stream_generator():\n                outputs = []\n                for token in self.generate(input_ids, generation_config=stream_config):\n                    outputs.append(token.item())\n                    yield tokenizer.decode(outputs, skip_special_tokens=True)\n\n            return stream_generator()\n        else:\n            self.__class__.generate = PreTrainedModel.generate  # disable stream\n            outputs = self.generate(input_ids, generation_config=generation_config)\n            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)\n            return response\n"
  },
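A usage sketch for the `chat()` entry point above. Note that `_build_chat_input` walks the message history from newest to oldest and left-truncates, so the most recent turns always survive; `user_token_id`, `assistant_token_id`, and `max_new_tokens` come from the checkpoint's `generation_config.json`, which is not reproduced in this listing. The path is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

path = "src/resources/models/baichuan-13B-chat"  # placeholder local path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()
model.generation_config = GenerationConfig.from_pretrained(path)

messages = [{"role": "user", "content": "世界上第二高的山峰是哪座？"}]
print(model.chat(tokenizer, messages))

# Streaming variant: yields progressively longer decoded strings;
# requires `pip install transformers-stream-generator`.
for partial in model.chat(tokenizer, messages, stream=True):
    print(partial, end="\r")
```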
  {
    "path": "src/resources/models/baichuan-13B-chat/tokenization_baichuan.py",
    "content": "# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.\n\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport sentencepiece as spm\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"tokenizer.model\"}\n\nPRETRAINED_VOCAB_FILES_MAP = {\n    \"vocab_file\": {},\n    \"tokenizer_file\": {},\n}\nPRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}\n\n\nclass BaichuanTokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=None,\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs\n        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token\n        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token\n        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token\n        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            add_bos_token=add_bos_token,\n            add_eos_token=add_eos_token,\n            sp_model_kwargs=self.sp_model_kwargs,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n        self.vocab_file = vocab_file\n        self.add_bos_token = add_bos_token\n        self.add_eos_token = add_eos_token\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(vocab_file)\n\n    def __getstate__(self):\n        state = self.__dict__.copy()\n        state[\"sp_model\"] = None\n        return state\n\n    def __setstate__(self, d):\n        self.__dict__ = d\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(self.vocab_file)\n\n    @property\n    def vocab_size(self):\n        \"\"\"Returns vocab size\"\"\"\n        return self.sp_model.get_piece_size()\n\n    def get_vocab(self):\n        \"\"\"Returns vocab as a dict\"\"\"\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text):\n        \"\"\"Returns a tokenized string.\"\"\"\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return 
self.sp_model.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        token = self.sp_model.IdToPiece(index)\n        return token\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for i, token in enumerate(tokens):\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special and i != 0:\n                    out_string += \" \"\n                out_string += self.sp_model.decode(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.sp_model.decode(current_sub_tokens)\n        return out_string\n\n    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = bos_token_id + token_ids_0 + eos_token_id\n\n        if token_ids_1 is not None:\n            output = output + bos_token_id + token_ids_1 + eos_token_id\n\n        return output\n\n    def get_special_tokens_mask(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieve sequence ids from a token list that has no special tokens added. 
This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n            )\n\n        bos_token_id = [1] if self.add_bos_token else []\n        eos_token_id = [1] if self.add_eos_token else []\n\n        if token_ids_1 is None:\n            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id\n        return (\n            bos_token_id\n            + ([0] * len(token_ids_0))\n            + eos_token_id\n            + bos_token_id\n            + ([0] * len(token_ids_1))\n            + eos_token_id\n        )\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT\n        sequence pair mask has the following format:\n\n        ```\n        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n        | first sequence    | second sequence |\n        ```\n\n        if token_ids_1 is None, only returns the first portion of the mask (0s).\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of ids.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).\n        \"\"\"\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)\n\n        if token_ids_1 is not None:\n            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)\n\n        return output\n\n"
  },
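With the defaults `add_bos_token=True` and `add_eos_token=False`, the special-token helpers above behave as in this small check (the vocab path is again a placeholder, and the ids are arbitrary):

```python
from tokenization_baichuan import BaichuanTokenizer

tok = BaichuanTokenizer(vocab_file="tokenizer.model")  # placeholder path

ids = [100, 101, 102]  # arbitrary token ids
print(tok.build_inputs_with_special_tokens(ids))           # [bos_id, 100, 101, 102]
print(tok.get_special_tokens_mask(ids))                    # [1, 0, 0, 0]
print(tok.create_token_type_ids_from_sequences(ids, ids))  # [0, 0, 0, 0, 1, 1, 1, 1]
```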
  {
    "path": "src/resources/models/baichuan-13B-chat/tokenizer_config.json",
    "content": "{\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_baichuan.BaichuanTokenizer\",\n      null\n    ]\n  },\n  \"bos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"clean_up_tokenization_spaces\": false,\n  \"eos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"</s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"model_max_length\": 4096,\n  \"pad_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<unk>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  },\n  \"sp_model_kwargs\": {},\n  \"tokenizer_class\": \"BaichuanTokenizer\",\n  \"unk_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<unk>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": true\n  }\n}\n"
  },
  {
    "path": "src/resources/models/baichuan-7B/config.json",
    "content": "{\n  \"architectures\": [\n    \"BaiChuanForCausalLM\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_baichuan.BaiChuanConfig\",\n    \"AutoModelForCausalLM\": \"modeling_baichuan.BaiChuanForCausalLM\"\n  },\n  \"bos_token_id\": 1,\n  \"eos_token_id\": 2,\n  \"hidden_act\": \"silu\",\n  \"hidden_size\": 4096,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": 11008,\n  \"max_position_embeddings\": 4096,\n  \"model_type\": \"baichuan\",\n  \"num_attention_heads\": 32,\n  \"num_hidden_layers\": 32,\n  \"pad_token_id\": 0,\n  \"rms_norm_eps\": 1e-06,\n  \"tie_word_embeddings\": false,\n  \"torch_dtype\": \"float32\",\n  \"transformers_version\": \"4.29.1\",\n  \"use_cache\": true,\n  \"vocab_size\": 64000\n}\n"
  },
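Unlike the 13B files above (ALiBi attention bias, `model_max_length`), the 7B variant below uses rotary position embeddings and declares `max_position_embeddings`. The dimensions in this config also account for the "7B" name; a back-of-the-envelope check (norm weights omitted):

```python
# Rough parameter count from baichuan-7B/config.json
V, H, I, L = 64000, 4096, 11008, 32
embed = V * H            # input embeddings; lm_head is the same size (untied)
attn = L * (4 * H * H)   # W_pack (3*H*H) + o_proj (H*H) per layer
mlp = L * (3 * H * I)    # gate_proj, up_proj, down_proj per layer
total = 2 * embed + attn + mlp
print(f"{total / 1e9:.2f}B parameters (excluding norms)")  # ~7.00B
```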
  {
    "path": "src/resources/models/baichuan-7B/configuration_baichuan.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass BaiChuanConfig(PretrainedConfig):\n    model_type = \"baichuan\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=64000,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
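The modeling file that follows implements rotary position embeddings (`RotaryEmbedding`, `rotate_half`, `apply_rotary_pos_emb`). A minimal self-contained numeric sketch of the core mechanism, using the same half-rotation convention as the code below — position 0 applies the identity rotation:

```python
import torch

def rotate_half(x):
    # Same convention as modeling_baichuan.py: swap halves, negate the second
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

dim, base = 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(16).float()                 # positions 0..15
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)      # [seq_len, dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(16, dim)
q_rot = q * cos + rotate_half(q) * sin       # position-dependent rotation
assert torch.allclose(q_rot[0], q[0])        # at position 0: cos=1, sin=0 -> identity
```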
  {
    "path": "src/resources/models/baichuan-7B/modeling_baichuan.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom .configuration_baichuan import BaiChuanConfig\nfrom transformers import PreTrainedModel, add_start_docstrings\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \\\n    SequenceClassifierOutputWithPast\nfrom transformers.utils import logging, add_start_docstrings_to_model_forward, replace_return_docstrings\n\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\n\nlogger = logging.get_logger(__name__)\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + 
self.variance_epsilon)\n\n        # convert into half-precision if necessary\n        if self.weight.dtype in [torch.float16, torch.bfloat16]:\n            hidden_states = hidden_states.to(self.weight.dtype)\n\n        return self.weight * hidden_states\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n        # Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.\n        if seq_len > self.max_seq_len_cached:\n            self.max_seq_len_cached = seq_len\n            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)\n            freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            # Different from paper, but it uses a different permutation in order to obtain the same calculation\n            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n            self.register_buffer(\"cos_cached\", emb.cos()[None, None, :, :], persistent=False)\n            self.register_buffer(\"sin_cached\", emb.sin()[None, None, :, :], persistent=False)\n        return (\n            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2:]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.\n    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]\n    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass MLP(nn.Module):\n    def __init__(\n            self,\n            hidden_size: int,\n            intermediate_size: int,\n            hidden_act: str,\n    ):\n        super().__init__()\n        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)\n        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)\n        self.act_fn = ACT2FN[hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass Attention(nn.Module):\n    
\"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: BaiChuanConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.max_position_embeddings = config.max_position_embeddings\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        # self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        # self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n        self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        proj = self.W_pack(hidden_states)\n        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)\n        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,\n                                                                                         2)  # batch_size x source_len x hidden_size\n        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,\n                                                                                       2)  # batch_size x target_len x head_size\n        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,\n                                                                                         2)  # batch_size x source_len x hidden_size\n\n        # query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        # key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        # value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n        # [bsz, nh, t, hd]\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            
key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass DecoderLayer(nn.Module):\n    def __init__(self, config: BaiChuanConfig):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = Attention(config=config)\n        self.mlp = MLP(\n            hidden_size=self.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_value: Optional[Tuple[torch.Tensor]] = None,\n            output_attentions: Optional[bool] = False,\n            use_cache: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass PreTrainedModel(PreTrainedModel):\n    config_class = BaiChuanConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DecoderLayer\"]\n    _keys_to_ignore_on_load_unexpected = [r\"decoder\\.version\"]\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, Model):\n            module.gradient_checkpointing = value\n\n\nclass Model(PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DecoderLayer`]\n\n    Args:\n        config: BaiChuanConfig\n    \"\"\"\n\n    def __init__(self, config: BaiChuanConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n                inputs_embeds.device\n            )\n            combined_attention_mask = (\n                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either decoder_input_ids or 
decoder_inputs_embeds\")\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device\n            )\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length\n        )\n\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if 
not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass BaiChuanForCausalLM(PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Model(config)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            labels: Optional[torch.LongTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, AutoModelForCausalLM\n\n        >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code=True)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)\n      
  return reordered_past\n"
  },
  {
    "path": "src/resources/models/baichuan-7B/tokenization_baichuan.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport sentencepiece as spm\n\nfrom transformers.tokenization_utils import AddedToken, PreTrainedTokenizer\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"tokenizer.model\"}\n\nPRETRAINED_VOCAB_FILES_MAP = {\n    \"vocab_file\": {},\n    \"tokenizer_file\": {},\n}\nPRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}\n\n\nclass BaiChuanTokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a BaiChuan tokenizer. Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=None,\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs\n        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token\n        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token\n        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token\n        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            add_bos_token=add_bos_token,\n            add_eos_token=add_eos_token,\n            sp_model_kwargs=self.sp_model_kwargs,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n        self.vocab_file = vocab_file\n        self.add_bos_token = add_bos_token\n        self.add_eos_token = add_eos_token\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(vocab_file)\n\n    def __getstate__(self):\n        state = self.__dict__.copy()\n        
state[\"sp_model\"] = None\n        return state\n\n    def __setstate__(self, d):\n        self.__dict__ = d\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(self.vocab_file)\n\n    @property\n    def vocab_size(self):\n        \"\"\"Returns vocab size\"\"\"\n        return self.sp_model.get_piece_size()\n\n    def get_vocab(self):\n        \"\"\"Returns vocab as a dict\"\"\"\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text):\n        \"\"\"Returns a tokenized string.\"\"\"\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return self.sp_model.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        token = self.sp_model.IdToPiece(index)\n        return token\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for i, token in enumerate(tokens):\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special and i != 0:\n                    out_string += \" \"\n                out_string += self.sp_model.decode(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.sp_model.decode(current_sub_tokens)\n        return out_string\n\n    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = bos_token_id + token_ids_0 + eos_token_id\n\n        if token_ids_1 is not None:\n            output = output + bos_token_id + token_ids_1 + eos_token_id\n\n        return output\n\n    
def get_special_tokens_mask(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n    ) -> List[int]:\n        \"\"\"\n        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(\n                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n            )\n\n        bos_token_id = [1] if self.add_bos_token else []\n        eos_token_id = [1] if self.add_eos_token else []\n\n        if token_ids_1 is None:\n            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id\n        return (\n            bos_token_id\n            + ([0] * len(token_ids_0))\n            + eos_token_id\n            + bos_token_id\n            + ([0] * len(token_ids_1))\n            + eos_token_id\n        )\n\n    def create_token_type_ids_from_sequences(\n        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT\n        sequence pair mask has the following format:\n\n        ```\n        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n        | first sequence    | second sequence |\n        ```\n\n        if token_ids_1 is None, only returns the first portion of the mask (0s).\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of ids.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).\n        \"\"\"\n        bos_token_id = [self.bos_token_id] if self.add_bos_token else []\n        eos_token_id = [self.eos_token_id] if self.add_eos_token else []\n\n        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)\n\n        if token_ids_1 is not None:\n            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)\n\n        return output"
  },
  {
    "path": "src/resources/models/baichuan-7B/tokenizer_config.json",
    "content": "{\n  \"auto_map\": {\n    \"AutoTokenizer\": [\"tokenization_baichuan.BaiChuanTokenizer\", null]\n  },\n  \"add_bos_token\": true,\n  \"add_eos_token\": false,\n  \"bos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": false\n  },\n  \"clean_up_tokenization_spaces\": false,\n  \"eos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"</s>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": false\n  },\n  \"model_max_length\": 1000000000000000019884624838656,\n  \"sp_model_kwargs\": {},\n  \"tokenizer_class\": \"BaiChuanTokenizer\",\n  \"unk_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<unk>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": false\n  }\n}\n"
  },
  {
    "path": "src/resources/models/chatglm-6B/README.md",
    "content": "---\nlanguage:\n- zh\n- en\ntags:\n- glm\n- chatglm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n--- \n\n本仓库由[THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) fork而来，原仓库实现了PyTorch版本的ChatGLM模型，该模型有60亿参数量，模型权重文件以FP16格式存储。\n\n本仓库在原始代码的基础上进行了部分调整，以支持ChatGPT训练pipeline，具体实现可参考：[sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\nThis repository is forked from [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) that contains PyTorch implementation of ChatGLM model with 6 billion parameters pretrained weights (FP16 precision).\n\nIt is slightly different from the original ChatGLM implementation to support the ChatGPT training pipeline in this github repo: [sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\n---\n\n# 介绍\nChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型，基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构，具有 62 亿参数。结合模型量化技术，用户可以在消费级的显卡上进行本地部署（INT4 量化级别下最低只需 6GB 显存）。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术，针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练，辅以监督微调、反馈自助、人类反馈强化学习等技术的加持，62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。\n\nChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.\n\n---\n\n# 软件依赖\n\n```shell\npip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels\n```\n\n---\n\n# 代码调用\n\n可以通过如下代码调用 ChatGLM-6B 模型来生成对话：\n\n```ipython\n>>> from transformers import AutoTokenizer, AutoModel\n>>> tokenizer = AutoTokenizer.from_pretrained(\"sunzeyeah/chatglm-6B\", trust_remote_code=True)\n>>> model = AutoModel.from_pretrained(\"sunzeyeah/chatglm-6B\", trust_remote_code=True).half().cuda()\n>>> response, history = model.chat(tokenizer, \"你好\", history=[])\n>>> print(response)\n你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。\n>>> response, history = model.chat(tokenizer, \"晚上睡不着应该怎么办\", history=history)\n>>> print(response)\n晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:\n\n1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。\n2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。\n3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。\n4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。\n5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。\n6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。\n\n如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。\n```\n\n"
  },
  {
    "path": "src/resources/models/chatglm-6B/config.json",
    "content": "{\n  \"_name_or_path\": \"sunzeyeah/chatglm-6B\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_chatglm.ChatGLMConfig\",\n    \"AutoModel\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\n    \"AutoModelForSeq2SeqLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\"\n  },\n  \"bos_token_id\": 150004,\n  \"eos_token_id\": 150005,\n  \"pad_token_id\": 20003,\n  \"hidden_size\": 4096,\n  \"inner_hidden_size\": 16384,\n  \"layernorm_epsilon\": 1e-05,\n  \"max_sequence_length\": 2048,\n  \"model_type\": \"chatglm\",\n  \"num_attention_heads\": 32,\n  \"num_layers\": 28,\n  \"position_encoding_2d\": true,\n  \"torch_dtype\": \"float16\",\n  \"transformers_version\": \"4.23.1\",\n  \"use_cache\": true,\n  \"vocab_size\": 150528\n}"
  },
  {
    "path": "src/resources/models/chatglm-6B/configuration_chatglm.py",
    "content": "\"\"\" ChatGLM model configuration \"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`~ChatGLMModel`].\n    It is used to instantiate an ChatGLM model according to the specified arguments, defining the model\n    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of\n    the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.\n\n    Configuration objects inherit from  [`PretrainedConfig`] and can be used\n    to control the model outputs. Read the documentation from  [`PretrainedConfig`]\n    for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 150528):\n            Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`~ChatGLMModel`] or\n            [`~TFChatGLMModel`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the encoder layers and the pooler layer.\n        num_hidden_layers (`int`, *optional*, defaults to 28):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        inner_hidden_size (`int`, *optional*, defaults to 16384):\n            Dimension of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        max_sequence_length (`int`, *optional*, defaults to 512):\n            The maximum sequence length that this model might ever be used with.\n            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).\n        layernorm_epsilon (`float`, *optional*, defaults to 1e-5):\n            The epsilon used by the layer normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether the model should return the last key/values attentions (not used by all models).\n        Example:\n\n    ```python\n    >>> from configuration_chatglm import ChatGLMConfig\n    >>> from modeling_chatglm import ChatGLMModel\n\n    >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration\n    >>> configuration = ChatGLMConfig()\n\n    >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration\n    >>> model = ChatGLMModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n\"\"\"\n    model_type = \"chatglm\"\n\n    def __init__(\n            self,\n            vocab_size=150528,\n            hidden_size=4096,\n            num_layers=28,\n            num_attention_heads=32,\n            layernorm_epsilon=1e-5,\n            use_cache=False,\n            bos_token_id=150004,\n            eos_token_id=150005,\n            pad_token_id=0,\n            max_sequence_length=2048,\n            inner_hidden_size=16384,\n            position_encoding_2d=True,\n            quantization_bit=0,\n            pre_seq_len=None,\n            prefix_projection=False,\n            **kwargs\n    ):\n        self.num_layers = num_layers\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_attention_heads = num_attention_heads\n        
self.max_sequence_length = max_sequence_length\n        self.layernorm_epsilon = layernorm_epsilon\n        self.inner_hidden_size = inner_hidden_size\n        self.use_cache = use_cache\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n        self.pad_token_id = pad_token_id\n        self.position_encoding_2d = position_encoding_2d\n        self.quantization_bit = quantization_bit\n        self.pre_seq_len = pre_seq_len\n        self.prefix_projection = prefix_projection\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            **kwargs\n        )"
  },
  {
    "path": "src/resources/models/chatglm-6B/modeling_chatglm.py",
    "content": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport os\nimport warnings\nimport re\nimport sys\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss, LayerNorm\n# from torch.nn.utils import skip_init\nfrom typing import Optional, Tuple, Union, List, Callable\n\nfrom transformers.utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    BaseModelOutputWithPastAndCrossAttentions,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig\n\nfrom .configuration_chatglm import ChatGLMConfig\n\n# flags required to enable jit fusion kernels\n\nif sys.platform != 'darwin':\n    torch._C._jit_set_profiling_mode(False)\n    torch._C._jit_set_profiling_executor(False)\n    torch._C._jit_override_can_fuse_on_cpu(True)\n    torch._C._jit_override_can_fuse_on_gpu(True)\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"THUDM/ChatGLM-6B\"\n_CONFIG_FOR_DOC = \"ChatGLM6BConfig\"\n\nCHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"THUDM/chatglm-6b\",\n    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm\n]\n\n\nclass InvalidScoreLogitsProcessor(LogitsProcessor):\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n        if torch.isnan(scores).any() or torch.isinf(scores).any():\n            scores.zero_()\n            scores[..., 20005] = 5e4\n        return scores\n\n\ndef load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):\n    \"\"\"Load tf checkpoints in a pytorch model.\"\"\"\n    try:\n        import re\n\n        import numpy as np\n        import tensorflow as tf\n    except ImportError:\n        logger.error(\n            \"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see \"\n            \"https://www.tensorflow.org/install/ for installation instructions.\"\n        )\n        raise\n    tf_path = os.path.abspath(tf_checkpoint_path)\n    logger.info(f\"Converting TensorFlow checkpoint from {tf_path}\")\n    # Load weights from TF model\n    init_vars = tf.train.list_variables(tf_path)\n    names = []\n    arrays = []\n    for name, shape in init_vars:\n        logger.info(f\"Loading TF weight {name} with shape {shape}\")\n        array = tf.train.load_variable(tf_path, name)\n        names.append(name)\n        arrays.append(array)\n\n    for name, array in zip(names, arrays):\n        name = name.split(\"/\")\n        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v\n        # which are not required for using pretrained model\n        if any(\n                n in [\"adam_v\", \"adam_m\", \"AdamWeightDecayOptimizer\", \"AdamWeightDecayOptimizer_1\", \"global_step\"]\n                for n in name\n        ):\n            logger.info(f\"Skipping {'/'.join(name)}\")\n            continue\n        pointer = model\n        for m_name in name:\n            if re.fullmatch(r\"[A-Za-z]+_\\d+\", m_name):\n                scope_names = re.split(r\"_(\\d+)\", m_name)\n            else:\n                scope_names = [m_name]\n            if scope_names[0] == \"kernel\" or scope_names[0] == \"gamma\":\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"output_bias\" or scope_names[0] == \"beta\":\n                pointer = getattr(pointer, \"bias\")\n            elif scope_names[0] == \"output_weights\":\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"squad\":\n                pointer = getattr(pointer, \"classifier\")\n            else:\n                try:\n                    pointer = getattr(pointer, scope_names[0])\n                except AttributeError:\n                    logger.info(f\"Skipping {'/'.join(name)}\")\n                    continue\n            if len(scope_names) >= 2:\n                num = int(scope_names[1])\n                pointer = pointer[num]\n        if m_name[-11:] == \"_embeddings\":\n            pointer = getattr(pointer, \"weight\")\n        elif m_name == \"kernel\":\n            array = np.transpose(array)\n        try:\n            assert (\n                    pointer.shape == array.shape\n            ), f\"Pointer shape {pointer.shape} and array shape {array.shape} mismatched\"\n        except AssertionError as e:\n            e.args += (pointer.shape, array.shape)\n            raise\n        logger.info(f\"Initialize PyTorch weight {name}\")\n        pointer.data = torch.from_numpy(array)\n    return model\n\n\nclass PrefixEncoder(torch.nn.Module):\n    \"\"\"\n    The torch.nn model to encode the prefix\n    Input shape: (batch-size, prefix-length)\n    Output shape: (batch-size, prefix-length, 2*layers*hidden)\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.prefix_projection = config.prefix_projection\n        if self.prefix_projection:\n            # Use a two-layer MLP to encode the prefix\n            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)\n            self.trans = torch.nn.Sequential(\n                torch.nn.Linear(config.hidden_size, config.hidden_size),\n                torch.nn.Tanh(),\n                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)\n            )\n     
   else:\n            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)\n\n    def forward(self, prefix: torch.Tensor):\n        if self.prefix_projection:\n            prefix_tokens = self.embedding(prefix)\n            past_key_values = self.trans(prefix_tokens)\n        else:\n            past_key_values = self.embedding(prefix)\n        return past_key_values\n\n\n@torch.jit.script\ndef gelu_impl(x):\n    \"\"\"OpenAI's gelu implementation.\"\"\"\n    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *\n                                       (1.0 + 0.044715 * x * x)))\n\n\ndef gelu(x):\n    return gelu_impl(x)\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):\n        super().__init__()\n        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))\n        inv_freq = inv_freq.half()\n        self.learnable = learnable\n        if learnable:\n            self.inv_freq = torch.nn.Parameter(inv_freq)\n            self.max_seq_len_cached = None\n        else:\n            self.register_buffer('inv_freq', inv_freq)\n            self.max_seq_len_cached = None\n            self.cos_cached = None\n            self.sin_cached = None\n        self.precision = precision\n\n    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,\n                              error_msgs):\n        pass\n\n    def forward(self, x, seq_dim=1, seq_len=None):\n        if seq_len is None:\n            seq_len = x.shape[seq_dim]\n        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):\n            self.max_seq_len_cached = None if self.learnable else seq_len\n            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)\n            freqs = torch.einsum('i,j->ij', t, self.inv_freq)\n            # Different from paper, but it uses a different permutation in order to obtain the same calculation\n            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n            if self.precision == torch.bfloat16:\n                emb = emb.float()\n\n            # [sx, 1 (b * np), hn]\n            cos_cached = emb.cos()[:, None, :]\n            sin_cached = emb.sin()[:, None, :]\n            if self.precision == torch.bfloat16:\n                cos_cached = cos_cached.bfloat16()\n                sin_cached = sin_cached.bfloat16()\n            if self.learnable:\n                return cos_cached, sin_cached\n            self.cos_cached, self.sin_cached = cos_cached, sin_cached\n        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]\n\n    def _apply(self, fn):\n        if self.cos_cached is not None:\n            self.cos_cached = fn(self.cos_cached)\n        if self.sin_cached is not None:\n            self.sin_cached = fn(self.sin_cached)\n        return super()._apply(fn)\n\n\ndef rotate_half(x):\n    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]\n    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions\n\n\n@torch.jit.script\ndef apply_rotary_pos_emb_index(q, k, cos, sin, position_id):\n    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]\n    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \\\n        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)\n    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)\n    return 
q, k\n\n\ndef attention_fn(\n        self,\n        query_layer,\n        key_layer,\n        value_layer,\n        attention_mask,\n        hidden_size_per_partition,\n        layer_id,\n        layer_past=None,\n        scaling_attention_score=True,\n        use_cache=False,\n):\n    if layer_past is not None:\n        past_key, past_value = layer_past[0], layer_past[1]\n        key_layer = torch.cat((past_key, key_layer), dim=0)\n        value_layer = torch.cat((past_value, value_layer), dim=0)\n\n    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head\n    seq_len, b, nh, hidden_size = key_layer.shape\n\n    if use_cache:\n        present = (key_layer, value_layer)\n    else:\n        present = None\n\n    query_key_layer_scaling_coeff = float(layer_id + 1)\n    if scaling_attention_score:\n        query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)\n\n    # ===================================\n    # Raw attention scores. [b, np, s, s]\n    # ===================================\n\n    # [b, np, sq, sk]\n    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))\n\n    # [sq, b, np, hn] -> [sq, b * np, hn]\n    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)\n    # [sk, b, np, hn] -> [sk, b * np, hn]\n    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)\n\n    matmul_result = torch.empty(\n        output_size[0] * output_size[1],\n        output_size[2],\n        output_size[3],\n        dtype=query_layer.dtype,\n        device=query_layer.device,\n    )\n\n    matmul_result = torch.baddbmm(\n        matmul_result,\n        query_layer.transpose(0, 1),  # [b * np, sq, hn]\n        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]\n        beta=0.0,\n        alpha=1.0,\n    )\n\n    # change view to [b, np, sq, sk]\n    attention_scores = matmul_result.view(*output_size)\n\n    if self.scale_mask_softmax:\n        self.scale_mask_softmax.scale = query_key_layer_scaling_coeff\n        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())\n    else:\n        if not (attention_mask == 0).all():\n            # if auto-regressive, skip\n            attention_scores.masked_fill_(attention_mask, -10000.0)\n        dtype = attention_scores.dtype\n        attention_scores = attention_scores.float()\n        attention_scores = attention_scores * query_key_layer_scaling_coeff\n\n        attention_probs = F.softmax(attention_scores, dim=-1)\n\n        attention_probs = attention_probs.type(dtype)\n\n    # =========================\n    # Context layer. 
[sq, b, hp]\n    # =========================\n\n    # value_layer -> context layer.\n    # [sk, b, np, hn] --> [b, np, sq, hn]\n\n    # context layer shape: [b, np, sq, hn]\n    output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))\n\n    # change view [sk, b * np, hn]\n    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)\n\n    # change view [b * np, sq, sk]\n    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n\n    # matmul: [b * np, sq, hn]\n    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))\n\n    # change view [b, np, sq, hn]\n    context_layer = context_layer.view(*output_size)\n\n    # [b, np, sq, hn] --> [sq, b, np, hn]\n    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()\n\n    # [sq, b, np, hn] --> [sq, b, hp]\n    new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)\n    context_layer = context_layer.view(*new_context_layer_shape)\n\n    outputs = (context_layer, present, attention_probs)\n\n    return outputs\n\n\nclass SelfAttention(torch.nn.Module):\n    def __init__(self, hidden_size, num_attention_heads,\n                 layer_id, hidden_size_per_attention_head=None, bias=True,\n                 params_dtype=torch.float, position_encoding_2d=True):\n        super(SelfAttention, self).__init__()\n\n        self.layer_id = layer_id\n        self.hidden_size = hidden_size\n        self.hidden_size_per_partition = hidden_size\n        self.num_attention_heads = num_attention_heads\n        self.num_attention_heads_per_partition = num_attention_heads\n        self.position_encoding_2d = position_encoding_2d\n        self.rotary_emb = RotaryEmbedding(\n            self.hidden_size // (self.num_attention_heads * 2)\n            if position_encoding_2d\n            else self.hidden_size // self.num_attention_heads,\n            base=10000,\n            precision=torch.half,\n            learnable=False,\n        )\n\n        self.scale_mask_softmax = None\n\n        if hidden_size_per_attention_head is None:\n            self.hidden_size_per_attention_head = hidden_size // num_attention_heads\n        else:\n            self.hidden_size_per_attention_head = hidden_size_per_attention_head\n\n        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head\n\n        # Strided linear layer.\n        # self.query_key_value = skip_init(\n        self.query_key_value = torch.nn.Linear(\n            hidden_size,\n            3 * self.inner_hidden_size,\n            bias=bias,\n            dtype=params_dtype,\n        )\n\n        # self.dense = skip_init(\n        self.dense = torch.nn.Linear(\n            self.inner_hidden_size,\n            hidden_size,\n            bias=bias,\n            dtype=params_dtype,\n        )\n\n    @staticmethod\n    def attention_mask_func(attention_scores, attention_mask):\n        attention_scores.masked_fill_(attention_mask, -10000.0)\n        return attention_scores\n\n    def split_tensor_along_last_dim(self, tensor, num_partitions,\n                                    contiguous_split_chunks=False):\n        \"\"\"Split a tensor along its last dimension.\n        Arguments:\n            tensor: input tensor.\n            num_partitions: number of partitions to split the tensor\n            contiguous_split_chunks: If True, make each chunk contiguous\n                                    in memory.\n        \"\"\"\n        # Get 
the size and dimension.\n        last_dim = tensor.dim() - 1\n        last_dim_size = tensor.size()[last_dim] // num_partitions\n        # Split.\n        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n        # Note: torch.split does not create contiguous tensors by default.\n        if contiguous_split_chunks:\n            return tuple(chunk.contiguous() for chunk in tensor_list)\n\n        return tensor_list\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            position_ids,\n            attention_mask: torch.Tensor,\n            layer_id,\n            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n            use_cache: bool = False,\n            output_attentions: bool = False,\n    ):\n        \"\"\"\n        hidden_states: [seq_len, batch, hidden_size]\n        attention_mask: [(1, 1), seq_len, seq_len]\n        \"\"\"\n\n        # [seq_len, batch, 3 * hidden_size]\n        mixed_raw_layer = self.query_key_value(hidden_states)\n\n        # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]\n        new_tensor_shape = mixed_raw_layer.size()[:-1] + (\n            self.num_attention_heads_per_partition,\n            3 * self.hidden_size_per_attention_head,\n        )\n        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)\n\n        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]\n        (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)\n\n        if self.position_encoding_2d:\n            q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))\n            k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))\n            cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)\n            position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \\\n                position_ids[:, 1, :].transpose(0, 1).contiguous()\n            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)\n            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)\n            query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))\n            key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))\n        else:\n            position_ids = position_ids.transpose(0, 1)\n            cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)\n            # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]\n            query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)\n\n        # [seq_len, batch, hidden_size]\n        context_layer, present, attention_probs = attention_fn(\n            self=self,\n            query_layer=query_layer,\n            key_layer=key_layer,\n            value_layer=value_layer,\n            attention_mask=attention_mask,\n            hidden_size_per_partition=self.hidden_size_per_partition,\n            layer_id=layer_id,\n            layer_past=layer_past,\n            use_cache=use_cache\n        )\n\n        output = self.dense(context_layer)\n\n        outputs = (output, present)\n\n        if output_attentions:\n            outputs += (attention_probs,)\n\n        return outputs  # output, present, attention_probs\n\n\nclass GEGLU(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.activation_fn = F.gelu\n\n    def forward(self, x):\n        # dim=-1 breaks in jit for pt<1.10\n 
       x1, x2 = x.chunk(2, dim=(x.ndim - 1))\n        return x1 * self.activation_fn(x2)\n\n\nclass GLU(torch.nn.Module):\n    def __init__(self, hidden_size, inner_hidden_size=None,\n                 layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):\n        super(GLU, self).__init__()\n        self.layer_id = layer_id\n        self.activation_func = activation_func\n\n        # Project to 4h.\n        self.hidden_size = hidden_size\n        if inner_hidden_size is None:\n            inner_hidden_size = 4 * hidden_size\n        self.inner_hidden_size = inner_hidden_size\n        # self.dense_h_to_4h = skip_init(\n        self.dense_h_to_4h = torch.nn.Linear(\n            self.hidden_size,\n            self.inner_hidden_size,\n            bias=bias,\n            dtype=params_dtype,\n        )\n        # Project back to h.\n        # self.dense_4h_to_h = skip_init(\n        self.dense_4h_to_h = torch.nn.Linear(\n            self.inner_hidden_size,\n            self.hidden_size,\n            bias=bias,\n            dtype=params_dtype,\n        )\n\n    def forward(self, hidden_states):\n        \"\"\"\n        hidden_states: [seq_len, batch, hidden_size]\n        \"\"\"\n\n        # [seq_len, batch, inner_hidden_size]\n        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n\n        intermediate_parallel = self.activation_func(intermediate_parallel)\n\n        output = self.dense_4h_to_h(intermediate_parallel)\n\n        return output\n\n\nclass GLMBlock(torch.nn.Module):\n    def __init__(\n            self,\n            hidden_size,\n            num_attention_heads,\n            layernorm_epsilon,\n            layer_id,\n            inner_hidden_size=None,\n            hidden_size_per_attention_head=None,\n            layernorm=LayerNorm,\n            use_bias=True,\n            params_dtype=torch.float,\n            num_layers=28,\n            position_encoding_2d=True\n    ):\n        super(GLMBlock, self).__init__()\n        # Set output layer initialization if not provided.\n\n        self.layer_id = layer_id\n\n        # Layernorm on the input data.\n        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)\n\n        self.position_encoding_2d = position_encoding_2d\n\n        # Self attention.\n        self.attention = SelfAttention(\n            hidden_size,\n            num_attention_heads,\n            layer_id,\n            hidden_size_per_attention_head=hidden_size_per_attention_head,\n            bias=use_bias,\n            params_dtype=params_dtype,\n            position_encoding_2d=self.position_encoding_2d\n        )\n\n        # Layernorm after the self attention.\n        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)\n\n        self.num_layers = num_layers\n\n        # GLU\n        self.mlp = GLU(\n            hidden_size,\n            inner_hidden_size=inner_hidden_size,\n            bias=use_bias,\n            layer_id=layer_id,\n            params_dtype=params_dtype,\n        )\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            position_ids,\n            attention_mask: torch.Tensor,\n            layer_id,\n            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n            use_cache: bool = False,\n            output_attentions: bool = False,\n    ):\n        \"\"\"\n        hidden_states: [seq_len, batch, hidden_size]\n        attention_mask: [(1, 1), seq_len, seq_len]\n        \"\"\"\n\n        # Layer norm at the beginning 
of the transformer layer.\n        # [seq_len, batch, hidden_size]\n        attention_input = self.input_layernorm(hidden_states)\n\n        # Self attention.\n        attention_outputs = self.attention(\n            attention_input,\n            position_ids,\n            attention_mask=attention_mask,\n            layer_id=layer_id,\n            layer_past=layer_past,\n            use_cache=use_cache,\n            output_attentions=output_attentions\n        )\n\n        attention_output = attention_outputs[0]\n\n        outputs = attention_outputs[1:]\n\n        # Residual connection.\n        alpha = (2 * self.num_layers) ** 0.5\n        hidden_states = attention_input * alpha + attention_output\n\n        mlp_input = self.post_attention_layernorm(hidden_states)\n\n        # MLP.\n        mlp_output = self.mlp(mlp_input)\n\n        # Second residual connection.\n        output = mlp_input * alpha + mlp_output\n\n        if use_cache:\n            outputs = (output,) + outputs\n        else:\n            outputs = (output,) + outputs[1:]\n\n        return outputs  # hidden_states, present, attentions\n\n\nclass ChatGLMPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    is_parallelizable = False\n    supports_gradient_checkpointing = True\n    config_class = ChatGLMConfig\n    base_model_prefix = \"transformer\"\n    _no_split_modules = [\"GLMBlock\"]\n\n    def __init__(self, *inputs, **kwargs):\n        super().__init__(*inputs, **kwargs)\n\n    def _init_weights(self, module: nn.Module):\n        \"\"\"Initialize the weights.\"\"\"\n        return\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, ChatGLMModel):\n            module.gradient_checkpointing = value\n\n\nCHATGLM_6B_START_DOCSTRING = r\"\"\"\n    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general\n    usage and behavior.\n\n    Parameters:\n        config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the configuration.\n            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nCHATGLM_6B_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `({0})`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using [`ChatGLM6BTokenizer`].\n            See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0, 1]`:\n\n            - 0 corresponds to a *sentence A* token,\n            - 1 corresponds to a *sentence B* token.\n\n            [What are token type IDs?](../glossary#token-type-ids)\n        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings.\n            Selected in the range `[0, config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert *input_ids* indices into associated vectors\n            than the model's internal embedding lookup matrix.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.\",\n    CHATGLM_6B_START_DOCSTRING,\n)\nclass ChatGLMModel(ChatGLMPreTrainedModel):\n    \"\"\"\n\n    The model can behave as an encoder (with only self-attention) as well\n    as a decoder, in which case a layer of cross-attention is added between\n    the self-attention layers, following the architecture described in [Attention is\n    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,\n    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.\n\n    To behave as a decoder the model needs to be initialized with the\n    `is_decoder` argument of the configuration set to `True`.\n    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`\n    argument and `add_cross_attention` set to `True`; an\n    `encoder_hidden_states` is then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig):\n        super().__init__(config)\n\n        # recording parameters\n        self.max_sequence_length = config.max_sequence_length\n        self.hidden_size = config.hidden_size\n        self.params_dtype = torch.half\n        self.num_attention_heads = config.num_attention_heads\n        self.vocab_size = config.vocab_size\n        self.num_layers = config.num_layers\n        self.layernorm_epsilon = config.layernorm_epsilon\n        self.inner_hidden_size = config.inner_hidden_size\n        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads\n        self.position_encoding_2d = config.position_encoding_2d\n        self.pre_seq_len = config.pre_seq_len\n        self.prefix_projection = config.prefix_projection\n\n        # self.word_embeddings = skip_init(\n        self.word_embeddings = torch.nn.Embedding(\n            num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,\n            dtype=self.params_dtype\n        )\n        self.gradient_checkpointing = False\n\n        def get_layer(layer_id):\n            return GLMBlock(\n                self.hidden_size,\n                self.num_attention_heads,\n                self.layernorm_epsilon,\n                layer_id,\n                inner_hidden_size=self.inner_hidden_size,\n                hidden_size_per_attention_head=self.hidden_size_per_attention_head,\n                layernorm=LayerNorm,\n                use_bias=True,\n                params_dtype=self.params_dtype,\n                position_encoding_2d=self.position_encoding_2d,\n            )\n\n        self.layers = torch.nn.ModuleList(\n            [get_layer(layer_id) for layer_id in range(self.num_layers)]\n        )\n\n        # Final layer norm before output.\n        self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)\n\n        if self.pre_seq_len is not None:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n            self.prefix_encoder = PrefixEncoder(config)\n            self.dropout = torch.nn.Dropout(0.1)\n\n            # total_params = sum(p.numel() for p in self.parameters())\n            # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)\n            # print(\"Using p-tuning v2: # trainable_params = {} / {}\".format(trainable_params, total_params))\n\n    def get_input_embeddings(self):\n        return self.word_embeddings\n\n    def set_input_embeddings(self, new_embeddings: torch.Tensor):\n        self.word_embeddings = new_embeddings\n\n    def get_prompt(self, batch_size, device, dtype=torch.half):\n        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)\n        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)\n        past_key_values = past_key_values.view(\n            batch_size,\n            self.pre_seq_len,\n            self.num_layers * 2,\n            self.num_attention_heads,\n            self.hidden_size // 
self.num_attention_heads\n        )\n        # permute below: [num_layers * 2, pre_seq_len, batch_size, num_attention_heads, head_dim]\n        past_key_values = self.dropout(past_key_values)\n        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)\n        # past_key_values = [(v[0], v[1]) for v in past_key_values]\n        return past_key_values\n\n    def get_masks(self, input_ids, device):\n        batch_size, seq_length = input_ids.shape\n        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]\n        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\n        attention_mask.tril_()\n        for i, context_length in enumerate(context_lengths):\n            attention_mask[i, :, :context_length] = 1\n        attention_mask.unsqueeze_(1)\n        attention_mask = (attention_mask < 0.5).bool()\n\n        return attention_mask\n\n    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):\n        batch_size, seq_length = input_ids.shape\n        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]\n        if self.position_encoding_2d:\n            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n            if not gmask:\n                for i, context_length in enumerate(context_lengths):\n                    position_ids[i, context_length:] = mask_positions[i]\n            block_position_ids = [torch.cat((\n                torch.zeros(context_length, dtype=torch.long, device=device),\n                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\n            )) for context_length in context_lengths]\n            block_position_ids = torch.stack(block_position_ids, dim=0)\n            position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n        else:\n            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n            if not gmask:\n                for i, context_length in enumerate(context_lengths):\n                    position_ids[i, context_length:] = mask_positions[i]\n\n        return position_ids\n\n    @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutputWithPastAndCrossAttentions,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n            self,\n            input_ids: Optional[torch.LongTensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n            inputs_embeds: Optional[torch.LongTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None 
else self.config.use_return_dict\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                # logger.warning_once(\n                #     \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                # )\n                use_cache = False\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        if past_key_values is None:\n            if self.pre_seq_len is not None:\n                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,\n                                                  dtype=inputs_embeds.dtype)\n            else:\n                past_key_values = tuple([None] * len(self.layers))\n\n            if attention_mask is None:\n                attention_mask = self.get_masks(\n                    input_ids,\n                    device=input_ids.device\n                )\n\n            if self.pre_seq_len is not None:\n                prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(\n                    attention_mask.device)\n                prefix_attention_mask = (prefix_attention_mask < 0.5).bool()\n                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)\n\n            if position_ids is None:\n                MASK, gMASK = 150000, 150001\n                mask_token = MASK if MASK in input_ids else gMASK\n                use_gmask = False if MASK in input_ids else True\n\n                mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]\n                position_ids = self.get_position_ids(\n                    input_ids,\n                    mask_positions=mask_positions,\n                    device=input_ids.device,\n                    gmask=use_gmask\n                )\n\n        # [seq_len, batch, hidden_size]\n        hidden_states = inputs_embeds.transpose(0, 1)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n\n        if attention_mask is None:\n            attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()\n\n        else:\n            attention_mask = attention_mask.to(input_ids.device)\n\n        for i, layer in enumerate(self.layers):\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n            layer_past = past_key_values[i]\n\n            if self.gradient_checkpointing and self.training:\n                layer_ret = torch.utils.checkpoint.checkpoint(\n                    layer,\n                    hidden_states,\n                    position_ids,\n                    attention_mask,\n                    torch.tensor(i),\n                    layer_past,\n                    use_cache,\n                    output_attentions\n                )\n            else:\n                layer_ret = layer(\n         
           hidden_states,\n                    position_ids=position_ids,\n                    attention_mask=attention_mask,\n                    layer_id=torch.tensor(i),\n                    layer_past=layer_past,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions\n                )\n\n            hidden_states = layer_ret[0]\n\n            if use_cache:\n                presents = presents + (layer_ret[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)\n\n        # Final layer norm.\n        hidden_states = self.final_layernorm(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig):\n        super().__init__(config)\n\n        # self.hidden_size = config.hidden_size\n        # self.params_dtype = torch.half\n        # self.vocab_size = config.vocab_size\n        self.max_sequence_length = config.max_sequence_length\n\n        self.position_encoding_2d = config.position_encoding_2d\n\n        self.transformer = ChatGLMModel(config)\n\n        # self.lm_head = skip_init(\n        self.lm_head = nn.Linear(\n            config.hidden_size,\n            config.vocab_size,\n            bias=False,\n            dtype=torch.half\n        )\n\n        self.config = config\n\n        self.quantized = False\n\n        if self.config.quantization_bit:\n            self.quantize(self.config.quantization_bit, empty_init=True)\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):\n        batch_size, seq_length = input_ids.shape\n        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]\n        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)\n        attention_mask.tril_()\n        for i, context_length in enumerate(context_lengths):\n            attention_mask[i, :, :context_length] = 1\n        attention_mask.unsqueeze_(1)\n        attention_mask = (attention_mask < 0.5).bool()\n\n        batch_size, seq_length = input_ids.shape\n        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]\n        if self.position_encoding_2d:\n            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n            if not gmask:\n                for i, context_length in enumerate(context_lengths):\n                    position_ids[i, context_length:] = mask_positions[i]\n            block_position_ids = [torch.cat((\n                torch.zeros(context_length, dtype=torch.long, device=device),\n                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1\n            )) for context_length in context_lengths]\n            block_position_ids 
= torch.stack(block_position_ids, dim=0)\n            position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n        else:\n            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)\n            if not gmask:\n                for i, context_length in enumerate(context_lengths):\n                    position_ids[i, context_length:] = mask_positions[i]\n\n        return attention_mask, position_ids\n\n    def prepare_inputs_for_generation(\n            self,\n            input_ids: torch.LongTensor,\n            past: Optional[torch.Tensor] = None,\n            past_key_values: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            **kwargs\n    ) -> dict:\n        batch_size, seq_length = input_ids.shape\n        MASK, gMASK = 150000, 150001\n        mask_token = MASK if MASK in input_ids else gMASK\n        use_gmask = False if MASK in input_ids else True\n        seqs = input_ids.tolist()\n        mask_positions = [seq.index(mask_token) for seq in seqs]\n\n        # only last token for input_ids if past is not None\n        if past is not None or past_key_values is not None:\n            context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]\n            last_token = input_ids[:, -1].unsqueeze(-1)\n            if self.position_encoding_2d:\n                position_ids = torch.tensor(\n                    [[mask_position, seq_length - context_length] for mask_position, context_length in\n                     zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)\n            else:\n                position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,\n                                            device=input_ids.device).unsqueeze(-1)\n\n            if past is None:\n                past = past_key_values\n            return {\n                \"input_ids\": last_token,\n                \"past_key_values\": past,\n                \"position_ids\": position_ids,\n            }\n        else:\n            attention_mask, position_ids = self.get_masks_and_position_ids(\n                input_ids,\n                mask_positions=mask_positions,\n                device=input_ids.device,\n                gmask=use_gmask\n            )\n\n            return {\n                \"input_ids\": input_ids,\n                \"past_key_values\": past,\n                \"position_ids\": position_ids,\n                \"attention_mask\": attention_mask\n            }\n\n    def forward(\n            self,\n            input_ids: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            labels: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ):\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            
attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()\n\n        loss = None\n        if labels is not None:\n            lm_logits = lm_logits.to(torch.float32)\n\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n            lm_logits = lm_logits.to(hidden_states.dtype)\n            loss = loss.to(hidden_states.dtype)\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(\n            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:\n        \"\"\"\n        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or\n        [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct\n        beam_idx at every generation step.\n\n        Output shares the same memory storage as `past`.\n        \"\"\"\n        return tuple(\n            (\n                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),\n                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),\n            )\n            for layer_past in past\n        )\n\n    def process_response(self, response):\n        response = response.strip()\n        response = response.replace(\"[[训练时间]]\", \"2023年\")\n        punkts = [\n            [\",\", \"，\"],\n            [\"!\", \"！\"],\n            [\":\", \"：\"],\n            [\";\", \"；\"],\n            [\"\\?\", \"？\"],\n        ]\n        for item in punkts:\n            response = re.sub(r\"([\\u4e00-\\u9fff])%s\" % item[0], r\"\\1%s\" % item[1], response)\n            response = re.sub(r\"%s([\\u4e00-\\u9fff])\" % item[0], r\"%s\\1\" % item[1], response)\n        return response\n\n    @torch.no_grad()\n    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,\n             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n            logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        gen_kwargs = {\"max_length\": max_length, \"num_beams\": num_beams, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        if not history:\n            prompt = query\n        else:\n            prompt = \"\"\n            for i, (old_query, response) in enumerate(history):\n                prompt += \"[Round {}]\\n问：{}\\n答：{}\\n\".format(i, old_query, response)\n            prompt += \"[Round {}]\\n问：{}\\n答：\".format(len(history), query)\n        input_ids = tokenizer([prompt], return_tensors=\"pt\", padding=True)\n        input_ids = input_ids.to(self.device)\n        outputs = self.generate(**input_ids, **gen_kwargs)\n        outputs = outputs.tolist()[0][len(input_ids[\"input_ids\"][0]):]\n        response = tokenizer.decode(outputs)\n        response = self.process_response(response)\n        history = history + [(query, response)]\n        return response, history\n\n    @torch.no_grad()\n    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,\n                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n            logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        gen_kwargs = {\"max_length\": max_length, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        if not history:\n            prompt = query\n        else:\n            prompt = \"\"\n            for i, (old_query, response) in enumerate(history):\n                prompt += \"[Round {}]\\n问：{}\\n答：{}\\n\".format(i, old_query, response)\n            prompt += \"[Round {}]\\n问：{}\\n答：\".format(len(history), query)\n        input_ids = tokenizer([prompt], return_tensors=\"pt\", padding=True)\n        input_ids = 
input_ids.to(self.device)\n        for outputs in self.stream_generate(**input_ids, **gen_kwargs):\n            outputs = outputs.tolist()[0][len(input_ids[\"input_ids\"][0]):]\n            response = tokenizer.decode(outputs)\n            response = self.process_response(response)\n            new_history = history + [(query, response)]\n            yield response, new_history\n\n    @torch.no_grad()\n    def stream_generate(\n            self,\n            input_ids,\n            generation_config: Optional[GenerationConfig] = None,\n            logits_processor: Optional[LogitsProcessorList] = None,\n            stopping_criteria: Optional[StoppingCriteriaList] = None,\n            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n            **kwargs,\n    ):\n        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]\n\n        if generation_config is None:\n            generation_config = self.generation_config\n        generation_config = copy.deepcopy(generation_config)\n        model_kwargs = generation_config.update(**kwargs)\n        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id\n\n        if isinstance(eos_token_id, int):\n            eos_token_id = [eos_token_id]\n\n        has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n        if has_default_max_length and generation_config.max_new_tokens is None:\n            warnings.warn(\n                f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. \"\n                \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n                \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n                UserWarning,\n            )\n        elif generation_config.max_new_tokens is not None:\n            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n            if not has_default_max_length:\n                logger.warning(\n                    f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n                    f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n                    \"Please refer to the documentation for more information. \"\n                    \"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\"\n                )\n\n        if input_ids_seq_length >= generation_config.max_length:\n            input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n            logger.warning(\n                f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n                f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n                \" increasing `max_new_tokens`.\"\n            )\n\n        # 2. 
Set generation parameters if not already defined\n        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n\n        logits_processor = self._get_logits_processor(\n            generation_config=generation_config,\n            input_ids_seq_length=input_ids_seq_length,\n            encoder_input_ids=input_ids,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            logits_processor=logits_processor,\n        )\n\n        stopping_criteria = self._get_stopping_criteria(\n            generation_config=generation_config, stopping_criteria=stopping_criteria\n        )\n        logits_warper = self._get_logits_warper(generation_config)\n\n        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)\n        scores = None\n        while True:\n            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n            # forward pass to get next token\n            outputs = self(\n                **model_inputs,\n                return_dict=True,\n                output_attentions=False,\n                output_hidden_states=False,\n            )\n\n            next_token_logits = outputs.logits[:, -1, :]\n\n            # pre-process distribution\n            next_token_scores = logits_processor(input_ids, next_token_logits)\n            next_token_scores = logits_warper(input_ids, next_token_scores)\n\n            # sample\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            if generation_config.do_sample:\n                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n            else:\n                next_tokens = torch.argmax(probs, dim=-1)\n\n            # update generated ids, model inputs, and length for next step\n            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n            model_kwargs = self._update_model_kwargs_for_generation(\n                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n            )\n            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())\n\n            # stop when each sentence is finished, or if we exceed the maximum length\n            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):\n                break\n            yield input_ids\n\n    def quantize(self, bits: int, empty_init=False, **kwargs):\n        if bits == 0:\n            return\n\n        from .quantization import quantize\n\n        if self.quantized:\n            logger.info(\"Already quantized.\")\n            return self\n\n        self.quantized = True\n\n        self.config.quantization_bit = bits\n\n        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)\n        return self"
  },
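
The 2D position encoding built by `ChatGLMModel.get_position_ids` above (and consumed by the paired rotary embeddings in `SelfAttention.forward`) is easiest to see on a toy input. Below is a minimal standalone sketch, not part of the repo: `toy_position_ids` is an illustrative name, and it assumes the usual ChatGLM convention that the context ends at the `<bos>` token and that generation uses `[gMASK]`.

```python
import torch

def toy_position_ids(seq_length: int, context_length: int, mask_position: int, gmask: bool = True):
    # Illustrative single-sequence version of get_position_ids (batch dimension dropped).
    # Row 0: absolute positions; for [MASK] decoding (gmask=False), every position
    # after the context is pinned to the mask position.
    position_ids = torch.arange(seq_length, dtype=torch.long)
    if not gmask:
        position_ids[context_length:] = mask_position
    # Row 1 (block positions): 0 inside the context, then 1, 2, ... over generated tokens.
    block_position_ids = torch.cat((
        torch.zeros(context_length, dtype=torch.long),
        torch.arange(seq_length - context_length, dtype=torch.long) + 1,
    ))
    return torch.stack((position_ids, block_position_ids))  # shape: [2, seq_length]

print(toy_position_ids(seq_length=8, context_length=5, mask_position=3))
# tensor([[0, 1, 2, 3, 4, 5, 6, 7],
#         [0, 0, 0, 0, 0, 1, 2, 3]])
```

When `position_encoding_2d=True`, the model applies row 0 to one half of each query/key (`q1`/`k1`) and row 1 to the other half (`q2`/`k2`).
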
  {
    "path": "src/resources/models/chatglm-6B/quantization.py",
    "content": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctypes\nfrom transformers.utils import logging\n\nfrom typing import List\nfrom functools import partial\n\nlogger = logging.get_logger(__name__)\n\ntry:\n    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up\n\n    class Kernel:\n        def __init__(self, code: bytes, function_names: List[str]):\n            self.code = code\n            self._function_names = function_names\n            self._cmodule = LazyKernelCModule(self.code)\n\n            for name in self._function_names:\n                setattr(self, name, KernelFunction(self._cmodule, name))\n\n    quantization_code = \"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZ
WYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAe
PUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ\"\n\n    kernels = Kernel(\n        bz2.decompress(base64.b64decode(quantization_code)),\n        [\n            \"int4WeightCompression\",\n            \"int4WeightExtractionFloat\",\n            \"int4WeightExtractionHalf\",\n            \"int8WeightExtractionFloat\",\n            \"int8WeightExtractionHalf\",\n        ],\n    )\nexcept Exception as exception:\n    kernels = None\n    logger.warning(\"Failed to load cpm_kernels:\" + str(exception))\n\n\nclass W8A16Linear(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):\n        ctx.inp_shape = inp.size()\n        ctx.weight_bit_width = weight_bit_width\n        out_features = quant_w.size(0)\n        inp = inp.contiguous().view(-1, inp.size(-1))\n        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)\n        ctx.weight_shape = weight.size()\n        output = inp.mm(weight.t())\n        ctx.save_for_backward(inp, quant_w, scale_w)\n        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor):\n        inp, quant_w, scale_w = ctx.saved_tensors\n        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)\n        grad_output = grad_output.contiguous().view(-1, weight.size(0))\n        grad_input = grad_output.mm(weight)\n        grad_weight = grad_output.t().mm(inp)\n        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None\n\n\ndef compress_int4_weight(weight: torch.Tensor):  # (n, m)\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        assert m % 2 == 0\n        m = m // 2\n        out = torch.empty(n, m, dtype=torch.int8, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        kernels.int4WeightCompression(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],\n        )\n        return out\n\n\ndef extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):\n    if source_bit_width == 8:\n        func = kernels.int8WeightExtractionHalf\n    elif source_bit_width == 4:\n        func = kernels.int4WeightExtractionHalf\n    else:\n        assert False, \"Unsupported 
bit-width\"\n\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        func(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [\n                ctypes.c_void_p(weight.data_ptr()),\n                ctypes.c_void_p(scale_list.data_ptr()),\n                ctypes.c_void_p(out.data_ptr()),\n                ctypes.c_int32(n),\n                ctypes.c_int32(m),\n            ],\n        )\n        return out\n\n\nclass QuantizedLinear(Linear):\n    def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):\n        super(QuantizedLinear, self).__init__(*args, **kwargs)\n        self.weight_bit_width = weight_bit_width\n\n        shape = self.weight.shape\n        del self.weight\n\n        if weight_tensor is None or empty_init:\n            self.weight = torch.empty(\n                shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs[\"device\"]\n            )\n            self.weight_scale = torch.empty(shape[0], dtype=kwargs[\"dtype\"], device=kwargs[\"device\"])\n        else:\n            self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()\n            self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)\n            if weight_bit_width == 4:\n                self.weight = compress_int4_weight(self.weight)\n\n        self.weight = Parameter(self.weight.to(kwargs[\"device\"]), requires_grad=False)\n        self.weight_scale = Parameter(self.weight_scale.to(kwargs[\"device\"]), requires_grad=False)\n        if bias_tensor is not None:\n            self.bias = Parameter(bias_tensor.to(kwargs[\"device\"]), requires_grad=False)\n        else:\n            self.bias = None\n\n    def forward(self, input):\n        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)\n        if self.bias is not None:\n            output = output + self.bias\n        return output\n\n\ndef quantize(model, weight_bit_width, empty_init=False, **kwargs):\n    \"\"\"Replace fp16 linear with quantized linear\"\"\"\n\n    for layer in model.layers:\n        layer.attention.query_key_value = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()),\n            bias_tensor=layer.attention.query_key_value.bias,\n            in_features=layer.attention.query_key_value.in_features,\n            out_features=layer.attention.query_key_value.out_features,\n            bias=True,\n            dtype=torch.half,\n            device=layer.attention.query_key_value.weight.device,\n            empty_init=empty_init\n        )\n        layer.attention.dense = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()),\n            bias_tensor=layer.attention.dense.bias,\n            in_features=layer.attention.dense.in_features,\n            out_features=layer.attention.dense.out_features,\n            bias=True,\n            dtype=torch.half,\n            device=layer.attention.dense.weight.device,\n            
empty_init=empty_init\n        )\n        layer.mlp.dense_h_to_4h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),\n            bias_tensor=layer.mlp.dense_h_to_4h.bias,\n            in_features=layer.mlp.dense_h_to_4h.in_features,\n            out_features=layer.mlp.dense_h_to_4h.out_features,\n            bias=True,\n            dtype=torch.half,\n            device=layer.mlp.dense_h_to_4h.weight.device,\n            empty_init=empty_init\n        )\n        layer.mlp.dense_4h_to_h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),\n            bias_tensor=layer.mlp.dense_4h_to_h.bias,\n            in_features=layer.mlp.dense_4h_to_h.in_features,\n            out_features=layer.mlp.dense_4h_to_h.out_features,\n            bias=True,\n            dtype=torch.half,\n            device=layer.mlp.dense_4h_to_h.weight.device,\n            empty_init=empty_init\n        )\n    return model"
  },
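The `QuantizedLinear` above implements symmetric per-row absmax quantization: each output row of the fp16 weight matrix gets a single fp16 scale (`max|w| / (2^(b-1) - 1)`), the rounded int8 values are stored (packed two-per-byte for int4 via `compress_int4_weight`), and the CUDA kernel dequantizes on the fly at matmul time. A minimal kernel-free sketch of the same round trip in plain PyTorch (names here are illustrative, not the repo's API):

```python
import torch

def absmax_quantize(weight: torch.Tensor, bit_width: int = 8):
    # Per-row symmetric quantization, mirroring QuantizedLinear.__init__:
    # scale = max|w| / (2^(b-1) - 1), q = round(w / scale)
    qmax = 2 ** (bit_width - 1) - 1
    scale = weight.abs().max(dim=-1).values / qmax        # one scale per output row
    qweight = torch.round(weight / scale[:, None]).to(torch.int8)
    return qweight, scale

def absmax_dequantize(qweight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # The inverse mapping that the extraction kernel computes on the GPU
    return qweight.to(scale.dtype) * scale[:, None]

w = torch.randn(4, 8)
qw, s = absmax_quantize(w)
print((w - absmax_dequantize(qw, s)).abs().max())  # small reconstruction error
```

This is also why the `empty_init` branch allocates `shape[1] * weight_bit_width // 8` int8 columns: at 4 bits, two quantized values share one byte.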
  {
    "path": "src/resources/models/chatglm-6B/tokenization_chatglm.py",
    "content": "\"\"\"Tokenization classes for ChatGLM.\"\"\"\nimport sys\nimport unicodedata\nfrom typing import List, Optional, Union\nfrom functools import lru_cache\nimport os\nimport collections\nimport re\n\nfrom transformers.tokenization_utils import PreTrainedTokenizer\nfrom icetk.text_tokenizer import TextTokenizer\nfrom icetk.utils import auto_create\nimport icetk.sentencepiece_model_pb2 as sp_model\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nPRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {\n    \"THUDM/chatglm-6b\": 2048,\n}\n\n\nclass SPTokenizer:\n    def __init__(\n            self,\n            vocab_file,\n            max_blank_length=80,\n            byte_fallback=True,\n    ):\n        assert vocab_file is not None\n        self.vocab_file = vocab_file\n        self.special_tokens = [\"[MASK]\", \"[gMASK]\", \"[sMASK]\", \"<unused_0>\", \"<sop>\", \"<eop>\", \"<ENC>\", \"<dBLOCK>\"]\n        self.max_blank_length = max_blank_length\n        self.byte_fallback = byte_fallback\n        self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False)\n        self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True)\n\n    @staticmethod\n    def _configure_tokenizer(\n            text_tokenizer: TextTokenizer,\n            special_tokens: List[str],\n            max_blank_length: int,\n            byte_fallback: bool,\n            encode_special_tokens=False,\n    ):\n        # special token\n        special_token_type = 4 if encode_special_tokens else 3  # 3 - CONTROL, 4 - USER_DEFINE\n        for token in special_tokens:\n            text_tokenizer.proto.pieces.append(\n                sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type)\n            )\n        # whitespaces\n        for token in [SPTokenizer.get_tab_token()] + [\n            SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1)\n        ]:\n            text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4))\n        # byte fallback\n        if byte_fallback:\n            text_tokenizer.proto.trainer_spec.byte_fallback = True\n            for i in range(256):\n                text_tokenizer.proto.pieces.append(\n                    sp_model.ModelProto.SentencePiece(piece=\"<0x{:02X}>\".format(i), score=0.0, type=6)\n                )\n        text_tokenizer.refresh()\n\n    def _build_text_tokenizer(self, encode_special_tokens=False):\n        tokenizer = TextTokenizer(self.vocab_file)\n        self._configure_tokenizer(\n            tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens\n        )\n        return tokenizer\n\n    def _get_text_tokenizer(self, encode_special_tokens=False):\n        if encode_special_tokens:\n            return self.special_text_tokenizer\n        else:\n            return self.text_tokenizer\n\n    @staticmethod\n    def get_blank_token(length: int):\n        assert length >= 2\n        return f\"<|blank_{length}|>\"\n\n    @staticmethod\n    def get_tab_token():\n        return f\"<|tab|>\"\n\n    @property\n    def num_image_tokens(self):\n        return 20000\n\n    @property\n    def num_text_tokens(self):\n        return self.text_tokenizer.num_tokens\n\n    @property\n    def num_tokens(self):\n        return self.num_image_tokens + self.num_text_tokens\n\n    @staticmethod\n    def _encode_whitespaces(text: str, max_len: int = 80):\n        text = 
text.replace(\"\\t\", SPTokenizer.get_tab_token())\n        for i in range(max_len, 1, -1):\n            text = text.replace(\" \" * i, SPTokenizer.get_blank_token(i))\n        return text\n\n    def _preprocess(self, text: str, linebreak=True, whitespaces=True):\n        if linebreak:\n            text = text.replace(\"\\n\", \"<n>\")\n        if whitespaces:\n            text = self._encode_whitespaces(text, max_len=self.max_blank_length)\n        return text\n\n    def encode(\n            self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True\n    ) -> List[int]:\n        \"\"\"\n        @param text: Text to encode.\n        @param linebreak: Whether to encode newline (\\n) in text.\n        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.\n        @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.\n        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.\n        \"\"\"\n        text = self._preprocess(text, linebreak, whitespaces)\n        if not add_dummy_prefix:\n            text = \"<n>\" + text\n        tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text)\n        tokens = [x + self.num_image_tokens for x in tmp]\n        return tokens if add_dummy_prefix else tokens[2:]\n\n    def decode(self, text_ids: List[int], special_tokens=False) -> str:\n        ids = [int(_id) - self.num_image_tokens for _id in text_ids]\n        text_tokenizer = self._get_text_tokenizer(encode_special_tokens=special_tokens)\n        if special_tokens:\n            text = \"\".join([text_tokenizer.convert_id_to_token(_id) for _id in ids if _id >= 0])\n        else:\n            ids = [_id for _id in ids if _id >= 0]\n            text = text_tokenizer.decode(ids)\n        text = text.replace(\"<n>\", \"\\n\")\n        text = text.replace(SPTokenizer.get_tab_token(), \"\\t\")\n        for i in range(2, self.max_blank_length + 1):\n            text = text.replace(self.get_blank_token(i), \" \" * i)\n        return text\n\n    def tokenize(\n            self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True\n    ) -> List[str]:\n        \"\"\"\n        @param text: Text to encode.\n        @param linebreak: Whether to encode newline (\\n) in text.\n        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.\n        @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) 
in text.\n        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.\n        \"\"\"\n        text = self._preprocess(text, linebreak, whitespaces)\n        if not add_dummy_prefix:\n            text = \"<n>\" + text\n        tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text)\n        return tokens if add_dummy_prefix else tokens[2:]\n\n    def __getitem__(self, x: Union[int, str]):\n        if isinstance(x, int):\n            if x < self.num_image_tokens:\n                return \"<image_{}>\".format(x)\n            else:\n                return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)\n        elif isinstance(x, str):\n            if x.startswith(\"<image_\") and x.endswith(\">\") and x[7:-1].isdigit():\n                return int(x[7:-1])\n            else:\n                return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens\n        else:\n            raise ValueError(\"The key should be str or int.\")\n\n\nclass ChatGLMTokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = {\"vocab_file\": \"ice_text.model\"}\n    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n    model_input_names = [\"input_ids\"]\n\n    def __init__(\n            self,\n            vocab_file,\n            do_lower_case=False,\n            remove_space=False,\n            bos_token='sop',\n            eos_token='eos',\n            eop_token='eop',\n            mask_token='[MASK]',\n            gmask_token='[gMASK]',\n            padding_side=\"right\",\n            **kwargs\n    ) -> None:\n        super().__init__(\n            do_lower_case=do_lower_case,\n            remove_space=remove_space,\n            padding_side=padding_side,\n            **kwargs\n        )\n\n        self.do_lower_case = do_lower_case\n        self.remove_space = remove_space\n        self.vocab_file = vocab_file\n\n        self.bos_token = bos_token\n        self.eos_token = eos_token\n        self.eop_token = eop_token\n        self.mask_token = mask_token\n        self.gmask_token = gmask_token\n\n        self.sp_tokenizer = SPTokenizer(vocab_file)\n\n        \"\"\" Initialisation \"\"\"\n\n    @property\n    def eop_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the `eop` token in the vocabulary. Returns `None` if the token has not been\n        set.\n        \"\"\"\n        if self.eop_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.eop_token)\n\n    @property\n    def gmask_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the `gMASK` token in the vocabulary. 
Returns `None` if the token has not been\n        set.\n        \"\"\"\n        if self.gmask_token is None:\n            return None\n        return self.convert_tokens_to_ids(self.gmask_token)\n\n    @property\n    def vocab_size(self):\n        \"\"\" Returns vocab size \"\"\"\n        return self.sp_tokenizer.num_tokens\n\n    def get_vocab(self):\n        \"\"\" Returns vocab as a dict \"\"\"\n        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def preprocess_text(self, inputs):\n        if self.remove_space:\n            outputs = \" \".join(inputs.strip().split())\n        else:\n            outputs = inputs\n\n        if self.do_lower_case:\n            outputs = outputs.lower()\n\n        return outputs\n\n    def _tokenize(self, text, **kwargs):\n        \"\"\" Returns a tokenized string. \"\"\"\n        text = self.preprocess_text(text)\n\n        seq = self.sp_tokenizer.tokenize(text)\n\n        return seq\n\n    def decode(\n            self,\n            token_ids: Union[List[int], List[List[int]]],\n            skip_special_tokens: bool = False,\n            clean_up_tokenization_spaces: bool = True,\n            spaces_between_special_tokens: bool = True,\n            **kwargs\n    ) -> str:\n        if isinstance(token_ids[0], list):\n            tokens = []\n            for single_token_ids in token_ids:\n                if skip_special_tokens and self.pad_token_id in single_token_ids:  # remove pad\n                    single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))\n                tokens.append(self.sp_tokenizer.decode(single_token_ids, not skip_special_tokens))\n            return tokens\n        else:\n            if skip_special_tokens and self.pad_token_id in token_ids:  # remove pad\n                token_ids = list(filter((self.pad_token_id).__ne__, token_ids))\n            return self.sp_tokenizer.decode(token_ids, not skip_special_tokens)\n\n    def _convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) to an id using the vocab. 
\"\"\"\n        return self.sp_tokenizer[token]\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.sp_tokenizer[index]\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the named of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"vocab_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)\n\n    # def num_special_tokens_to_add(self, pair: bool = False) -> int:\n    #     \"\"\"\n    #     Returns the number of added tokens when encoding a sequence with special tokens.\n    #\n    #     <Tip>\n    #\n    #     This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put\n    #     this inside your training loop.\n    #\n    #     </Tip>\n    #\n    #     Args:\n    #         pair (`bool`, *optional*, defaults to `False`):\n    #             Whether the number of added tokens should be computed in the case of a sequence pair or a single\n    #             sequence.\n    #\n    #     Returns:\n    #         `int`: Number of special tokens added to sequences.\n    #     \"\"\"\n    #     # token_ids_0 = []\n    #     # token_ids_1 = []\n    #     # return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))\n    #     return 2\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A ChatGLM sequence has the following format:\n\n        - single sequence: `X [gMASK] <sop>`\n        - pair of sequences: `A [gMASK] <sop> B <eop>`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        mask_id = self.sp_tokenizer[self.mask_token]\n        gmask_id = self.sp_tokenizer[self.gmask_token]\n        eos_id = self.sp_tokenizer[self.eos_token]\n        bos_id = self.sp_tokenizer[self.bos_token]\n        eop_id = self.sp_tokenizer[self.eop_token]\n\n        if mask_id not in token_ids_0 and gmask_id not in token_ids_0:\n            token_ids_0 += [gmask_id]\n\n        if token_ids_0[-1] != mask_id and token_ids_0[-1] != gmask_id:\n            token_ids_0 += [eos_id]\n\n        token_ids_0 += [bos_id]\n\n        if token_ids_1 is not None:\n            token_ids_0 += token_ids_1 + [eop_id]\n\n        return token_ids_0"
  },
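A detail worth calling out in `SPTokenizer`: before SentencePiece runs, `_preprocess` rewrites whitespace so it survives tokenization atomically; newlines become `<n>`, tabs `<|tab|>`, and runs of 2 to `max_blank_length` spaces become `<|blank_n|>`, with `decode` applying the inverse mapping. A self-contained sketch of just this round trip (no icetk dependency; it mirrors `_encode_whitespaces` and `decode` above):

```python
MAX_BLANK = 80  # matches SPTokenizer's max_blank_length default

def encode_whitespaces(text: str) -> str:
    # Mirrors SPTokenizer._preprocess: newline first, then tab, then space runs
    text = text.replace("\n", "<n>").replace("\t", "<|tab|>")
    for i in range(MAX_BLANK, 1, -1):             # longest runs first
        text = text.replace(" " * i, f"<|blank_{i}|>")
    return text

def decode_whitespaces(text: str) -> str:
    # Inverse mapping, as in SPTokenizer.decode
    for i in range(2, MAX_BLANK + 1):
        text = text.replace(f"<|blank_{i}|>", " " * i)
    return text.replace("<|tab|>", "\t").replace("<n>", "\n")

src = "def f():\n\treturn  1"
enc = encode_whitespaces(src)
assert decode_whitespaces(enc) == src
print(enc)  # def f():<n><|tab|>return<|blank_2|>1
```

This matters mainly for source-code inputs, where indentation carries meaning that a plain SentencePiece model would otherwise normalize away.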
  {
    "path": "src/resources/models/chatglm-6B/tokenizer_config.json",
    "content": "{\n  \"name_or_path\": \"THUDM/chatglm-6b\",\n  \"bos_token\": \"<sop>\",\n  \"eop_token\": \"<eop>\",\n  \"eos_token\": \"</s>\",\n  \"gmask_token\": \"[gMASK]\",\n  \"mask_token\": \"[MASK]\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"ChatGLMTokenizer\",\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_chatglm.ChatGLMTokenizer\",\n      null\n      ]\n  }\n}\n"
  },
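Because `tokenizer_config.json` carries the `auto_map` entry, `transformers` can resolve `ChatGLMTokenizer` straight from the checkpoint directory without the class being registered in the library. A minimal loading sketch (the local path is illustrative; `trust_remote_code=True` is required precisely because the class ships alongside the weights):

```python
from transformers import AutoTokenizer

# auto_map routes AutoTokenizer to tokenization_chatglm.ChatGLMTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "src/resources/models/chatglm-6B", trust_remote_code=True
)

ids = tokenizer("你好")["input_ids"]
print(ids)                    # ends with the [gMASK] and <sop> ids added above
print(tokenizer.decode(ids))
```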
  {
    "path": "src/resources/models/chatglm2-6B/config.json",
    "content": "{\n  \"model_type\": \"chatglm\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_chatglm.ChatGLMConfig\",\n    \"AutoModel\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\n    \"AutoModelForSeq2SeqLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\"\n  },\n  \"add_bias_linear\": false,\n  \"add_qkv_bias\": true,\n  \"apply_query_key_layer_scaling\": true,\n  \"apply_residual_connection_post_layernorm\": false,\n  \"attention_dropout\": 0.0,\n  \"attention_softmax_in_fp32\": true,\n  \"bias_dropout_fusion\": true,\n  \"ffn_hidden_size\": 13696,\n  \"fp32_residual_connection\": false,\n  \"hidden_dropout\": 0.0,\n  \"hidden_size\": 4096,\n  \"kv_channels\": 128,\n  \"layernorm_epsilon\": 1e-05,\n  \"multi_query_attention\": true,\n  \"multi_query_group_num\": 2,\n  \"num_attention_heads\": 32,\n  \"num_layers\": 28,\n  \"original_rope\": true,\n  \"padded_vocab_size\": 65024,\n  \"post_layer_norm\": true,\n  \"rmsnorm\": true,\n  \"seq_length\": 32768,\n  \"use_cache\": true,\n  \"torch_dtype\": \"float16\",\n  \"transformers_version\": \"4.27.1\",\n  \"tie_word_embeddings\": false,\n  \"eos_token_id\": 2,\n  \"pad_token_id\": 2\n}"
  },
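Two settings in this config do most of the memory work at inference time: `multi_query_attention: true` with `multi_query_group_num: 2` means keys and values are computed for only 2 groups and broadcast across the 32 query heads, cutting the KV cache by 16x. A back-of-envelope sizing check using the numbers above (a sketch, assuming an fp16 cache over the full 32768-token context):

```python
# From config.json: 28 layers, kv_channels=128, fp16 (2 bytes), seq_length=32768
KV_CHANNELS, NUM_LAYERS, BYTES, SEQ_LEN = 128, 28, 2, 32768

def kv_cache_bytes(kv_heads: int) -> int:
    # 2 tensors (K and V) per layer, each of shape [seq, kv_heads, kv_channels]
    return 2 * kv_heads * KV_CHANNELS * BYTES * NUM_LAYERS * SEQ_LEN

print(f"full MHA (32 kv heads): {kv_cache_bytes(32) / 2**30:.1f} GiB")  # ~14.0 GiB
print(f"MQA (2 kv groups):      {kv_cache_bytes(2) / 2**30:.2f} GiB")   # ~0.88 GiB
```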
  {
    "path": "src/resources/models/chatglm2-6B/configuration_chatglm.py",
    "content": "from transformers import PretrainedConfig\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    def __init__(\n        self,\n        num_layers=28,\n        padded_vocab_size=65024,\n        hidden_size=4096,\n        ffn_hidden_size=13696,\n        kv_channels=128,\n        num_attention_heads=32,\n        seq_length=2048,\n        hidden_dropout=0.0,\n        attention_dropout=0.0,\n        layernorm_epsilon=1e-5,\n        rmsnorm=True,\n        apply_residual_connection_post_layernorm=False,\n        post_layer_norm=True,\n        add_bias_linear=False,\n        add_qkv_bias=False,\n        interleaved_qkv=False,\n        bias_dropout_fusion=True,\n        multi_query_attention=False,\n        multi_query_group_num=1,\n        apply_query_key_layer_scaling=True,\n        attention_softmax_in_fp32=True,\n        fp32_residual_connection=False,\n        quantization_bit=0,\n        **kwargs\n    ):\n        self.num_layers = num_layers\n        self.padded_vocab_size = padded_vocab_size\n        self.hidden_size = hidden_size\n        self.ffn_hidden_size = ffn_hidden_size\n        self.kv_channels = kv_channels\n        self.num_attention_heads = num_attention_heads\n        self.seq_length = seq_length\n        self.hidden_dropout = hidden_dropout\n        self.attention_dropout = attention_dropout\n        self.layernorm_epsilon = layernorm_epsilon\n        self.rmsnorm = rmsnorm\n        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm\n        self.post_layer_norm = post_layer_norm\n        self.add_bias_linear = add_bias_linear\n        self.add_qkv_bias = add_qkv_bias\n        self.bias_dropout_fusion = bias_dropout_fusion\n        self.multi_query_attention = multi_query_attention\n        self.multi_query_group_num = multi_query_group_num\n        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = attention_softmax_in_fp32\n        self.fp32_residual_connection = fp32_residual_connection\n        self.quantization_bit = quantization_bit\n        super().__init__(**kwargs)"
  },
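Note that the class defaults here describe a smaller baseline (`seq_length=2048`, `add_qkv_bias=False`, `multi_query_attention=False`); the shipped `config.json` overrides them when the model is loaded. A short sketch of building the same configuration directly (assumes `configuration_chatglm.py` is importable from the model directory; extra kwargs such as `torch_dtype` are absorbed by `PretrainedConfig.__init__`):

```python
from configuration_chatglm import ChatGLMConfig

# Reproduce the key config.json overrides on top of the class defaults
config = ChatGLMConfig(
    seq_length=32768,
    add_qkv_bias=True,
    multi_query_attention=True,
    multi_query_group_num=2,
    torch_dtype="float16",
)
print(config.seq_length, config.multi_query_group_num)  # 32768 2
```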
  {
    "path": "src/resources/models/chatglm2-6B/modeling_chatglm.py",
    "content": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport warnings\nimport re\nimport sys\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss, LayerNorm\nfrom torch.nn.utils import skip_init\nfrom typing import Optional, Tuple, Union, List, Callable, Dict, Any\n\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput\n\nfrom .configuration_chatglm import ChatGLMConfig\n\n# flags required to enable jit fusion kernels\n\nif sys.platform != 'darwin':\n    torch._C._jit_set_profiling_mode(False)\n    torch._C._jit_set_profiling_executor(False)\n    torch._C._jit_override_can_fuse_on_cpu(True)\n    torch._C._jit_override_can_fuse_on_gpu(True)\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"THUDM/ChatGLM2-6B\"\n_CONFIG_FOR_DOC = \"ChatGLM6BConfig\"\n\nCHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"THUDM/chatglm2-6b\",\n    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm\n]\n\n\ndef default_init(cls, *args, **kwargs):\n    return cls(*args, **kwargs)\n\n\nclass InvalidScoreLogitsProcessor(LogitsProcessor):\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n        if torch.isnan(scores).any() or torch.isinf(scores).any():\n            scores.zero_()\n            scores[..., 5] = 5e4\n        return scores\n\n\ndef split_tensor_along_last_dim(\n        tensor: torch.Tensor,\n        num_partitions: int,\n        contiguous_split_chunks: bool = False,\n) -> List[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = tensor.size()[last_dim] // num_partitions\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # Note: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, original_impl=False, device=None, dtype=None):\n        super().__init__()\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n        self.dim = dim\n        self.original_impl = original_impl\n\n    def forward_impl(\n            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000\n    ):\n        \"\"\"Enhanced Transformer with Rotary Position Embedding.\n\n        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/\n        transformers/rope/__init__.py. 
MIT License:\n        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.\n        \"\"\"\n        # $\\Theta = {\\theta_i = 10000^{\\frac{2(i-1)}{d}}, i \\in [1, 2, ..., \\frac{d}{2}]}$\n        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))\n\n        # Create position indexes `[0, 1, ..., seq_len - 1]`\n        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)\n\n        # Calculate the product of position index and $\\theta_i$\n        idx_theta = torch.outer(seq_idx, theta).float()\n\n        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)\n\n        # this is to mimic the behaviour of complex32, else we will get different results\n        if dtype in (torch.float16, torch.bfloat16, torch.int8):\n            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()\n        return cache\n\n    def forward(self, max_seq_len, offset=0):\n        return self.forward_impl(\n            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device\n        )\n\n\n@torch.jit.script\ndef apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:\n    # x: [sq, b, np, hn]\n    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)\n    rot_dim = rope_cache.shape[-2] * 2\n    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]\n    # truncate to support variable sizes\n    rope_cache = rope_cache[:sq]\n    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)\n    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)\n    x_out2 = torch.stack(\n        [\n            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],\n            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],\n        ],\n        -1,\n    )\n    x_out2 = x_out2.flatten(3)\n    return torch.cat((x_out2, x_pass), dim=-1)\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))\n        self.eps = eps\n\n    def forward(self, hidden_states: torch.Tensor):\n        input_dtype = hidden_states.dtype\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)\n\n        return (self.weight * hidden_states).to(input_dtype)\n\n\nclass CoreAttention(torch.nn.Module):\n    def __init__(self, config: ChatGLMConfig, layer_number):\n        super(CoreAttention, self).__init__()\n\n        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32\n        if self.apply_query_key_layer_scaling:\n            self.attention_softmax_in_fp32 = True\n        self.layer_number = max(1, layer_number)\n\n        projection_size = config.kv_channels * config.num_attention_heads\n\n        # Per attention head and per partition values.\n        self.hidden_size_per_partition = projection_size\n        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads\n        self.num_attention_heads_per_partition = config.num_attention_heads\n\n        coeff = None\n        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)\n        if self.apply_query_key_layer_scaling:\n            coeff = self.layer_number\n          
  self.norm_factor *= coeff\n        self.coeff = coeff\n\n        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)\n\n    def forward(self, query_layer, key_layer, value_layer, attention_mask):\n        pytorch_major_version = int(torch.__version__.split('.')[0])\n        if pytorch_major_version >= 2:\n            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]\n            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:\n                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n                                                                                 is_causal=True)\n            else:\n                if attention_mask is not None and attention_mask.dtype == torch.bool:\n                    attention_mask = ~attention_mask\n                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n                                                                                 attention_mask)\n            context_layer = context_layer.permute(2, 0, 1, 3)\n            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n            context_layer = context_layer.reshape(*new_context_layer_shape)\n        else:\n            # Raw attention scores\n\n            # [b, np, sq, sk]\n            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))\n\n            # [sq, b, np, hn] -> [sq, b * np, hn]\n            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)\n            # [sk, b, np, hn] -> [sk, b * np, hn]\n            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)\n\n            # preallocting input tensor: [b * np, sq, sk]\n            matmul_input_buffer = torch.empty(\n                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,\n                device=query_layer.device\n            )\n\n            # Raw attention scores. 
[b * np, sq, sk]\n            matmul_result = torch.baddbmm(\n                matmul_input_buffer,\n                query_layer.transpose(0, 1),  # [b * np, sq, hn]\n                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]\n                beta=0.0,\n                alpha=(1.0 / self.norm_factor),\n            )\n\n            # change view to [b, np, sq, sk]\n            attention_scores = matmul_result.view(*output_size)\n\n            # ===========================\n            # Attention probs and dropout\n            # ===========================\n\n            # attention scores and attention mask [b, np, sq, sk]\n            if self.attention_softmax_in_fp32:\n                attention_scores = attention_scores.float()\n            if self.coeff is not None:\n                attention_scores = attention_scores * self.coeff\n            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:\n                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],\n                                            device=attention_scores.device, dtype=torch.bool)\n                attention_mask.tril_()\n                attention_mask = ~attention_mask\n            if attention_mask is not None:\n                if attention_mask.dtype == torch.bool:\n                    attention_scores = attention_scores.masked_fill(attention_mask, float(\"-inf\"))\n                else:\n                    # additive float mask: offset the raw scores rather than replacing them\n                    attention_scores = attention_scores + attention_mask\n            attention_probs = F.softmax(attention_scores, dim=-1)\n            attention_probs = attention_probs.type_as(value_layer)\n\n            # This is actually dropping out entire tokens to attend to, which might\n            # seem a bit unusual, but is taken from the original Transformer paper.\n            attention_probs = self.attention_dropout(attention_probs)\n            # =========================\n            # Context layer. 
[sq, b, hp]\n            # =========================\n\n            # value_layer -> context layer.\n            # [sk, b, np, hn] --> [b, np, sq, hn]\n\n            # context layer shape: [b, np, sq, hn]\n            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))\n            # change view [sk, b * np, hn]\n            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)\n            # change view [b * np, sq, sk]\n            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n            # matmul: [b * np, sq, hn]\n            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))\n            # change view [b, np, sq, hn]\n            context_layer = context_layer.view(*output_size)\n            # [b, np, sq, hn] --> [sq, b, np, hn]\n            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()\n            # [sq, b, np, hn] --> [sq, b, hp]\n            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n            context_layer = context_layer.view(*new_context_layer_shape)\n\n        return context_layer\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"Parallel self-attention layer abstract class.\n\n    Self-attention layer takes input with size [s, b, h]\n    and returns output of the same size.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n        super(SelfAttention, self).__init__()\n        self.layer_number = max(1, layer_number)\n\n        self.projection_size = config.kv_channels * config.num_attention_heads\n\n        # Per attention head and per partition values.\n        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads\n        self.num_attention_heads_per_partition = config.num_attention_heads\n\n        self.multi_query_attention = config.multi_query_attention\n        self.qkv_hidden_size = 3 * self.projection_size\n        if self.multi_query_attention:\n            self.num_multi_query_groups_per_partition = config.multi_query_group_num\n            self.qkv_hidden_size = (\n                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num\n            )\n        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,\n                                         bias=config.add_bias_linear or config.add_qkv_bias,\n                                         device=device, **_config_to_kwargs(config)\n                                         )\n\n        self.core_attention = CoreAttention(config, self.layer_number)\n\n        # Output.\n        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,\n                               device=device, **_config_to_kwargs(config)\n                               )\n\n    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):\n        if self.multi_query_attention:\n            num_attention_heads = self.num_multi_query_groups_per_partition\n        else:\n            num_attention_heads = self.num_attention_heads_per_partition\n        return torch.empty(\n            inference_max_sequence_len,\n            batch_size,\n            num_attention_heads,\n            self.hidden_size_per_attention_head,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(\n            self, 
hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True\n    ):\n        # hidden_states: [sq, b, h]\n\n        # =================================================\n        # Pre-allocate memory for key-values for inference.\n        # =================================================\n        # =====================\n        # Query, Key, and Value\n        # =====================\n\n        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]\n        mixed_x_layer = self.query_key_value(hidden_states)\n\n        if self.multi_query_attention:\n            (query_layer, key_layer, value_layer) = mixed_x_layer.split(\n                [\n                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,\n                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n                ],\n                dim=-1,\n            )\n            query_layer = query_layer.view(\n                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n            )\n            key_layer = key_layer.view(\n                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n            )\n            value_layer = value_layer.view(\n                value_layer.size()[:-1]\n                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n            )\n        else:\n            new_tensor_shape = mixed_x_layer.size()[:-1] + \\\n                               (self.num_attention_heads_per_partition,\n                                3 * self.hidden_size_per_attention_head)\n            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)\n\n            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]\n            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n\n        # apply relative positional encoding (rotary embedding)\n        if rotary_pos_emb is not None:\n            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)\n            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)\n\n        # adjust key and value for inference\n        if use_cache:\n            if kv_cache is not None:\n                cache_k, cache_v = kv_cache\n                key_layer = torch.cat((cache_k, key_layer), dim=0)\n                value_layer = torch.cat((cache_v, value_layer), dim=0)\n            kv_cache = (key_layer, value_layer)\n        else:\n            kv_cache = None\n\n        if self.multi_query_attention:\n            key_layer = key_layer.unsqueeze(-2)\n            key_layer = key_layer.expand(\n                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1\n            )\n            key_layer = key_layer.contiguous().view(\n                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n            )\n            value_layer = value_layer.unsqueeze(-2)\n            value_layer = value_layer.expand(\n                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1\n            )\n            value_layer = value_layer.contiguous().view(\n                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n     
       )\n\n        # ==================================\n        # core attention computation\n        # ==================================\n\n        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)\n\n        # =================\n        # Output. [sq, b, h]\n        # =================\n\n        output = self.dense(context_layer)\n\n        return output, kv_cache\n\n\ndef _config_to_kwargs(args):\n    common_kwargs = {\n        \"dtype\": args.torch_dtype,\n    }\n    return common_kwargs\n\n\nclass MLP(torch.nn.Module):\n    \"\"\"MLP.\n\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, perform nonlinear transformation, and project the\n    state back into h hidden dimension.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        super(MLP, self).__init__()\n\n        self.add_bias = config.add_bias_linear\n\n        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf\n        self.dense_h_to_4h = nn.Linear(\n            config.hidden_size,\n            config.ffn_hidden_size * 2,\n            bias=self.add_bias,\n            device=device,\n            **_config_to_kwargs(config)\n        )\n\n        def swiglu(x):\n            x = torch.chunk(x, 2, dim=-1)\n            return F.silu(x[0]) * x[1]\n\n        self.activation_func = swiglu\n\n        # Project back to h.\n        self.dense_4h_to_h = nn.Linear(\n            config.ffn_hidden_size,\n            config.hidden_size,\n            bias=self.add_bias,\n            device=device,\n            **_config_to_kwargs(config)\n        )\n\n    def forward(self, hidden_states):\n        # [s, b, 4hp]\n        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n        intermediate_parallel = self.activation_func(intermediate_parallel)\n        # [s, b, h]\n        output = self.dense_4h_to_h(intermediate_parallel)\n        return output\n\n\nclass GLMBlock(torch.nn.Module):\n    \"\"\"A single transformer layer.\n\n    Transformer layer takes input with size [s, b, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n        super(GLMBlock, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm\n\n        self.fp32_residual_connection = config.fp32_residual_connection\n\n        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                             dtype=config.torch_dtype)\n\n        # Self attention.\n        self.self_attention = SelfAttention(config, layer_number, device=device)\n        self.hidden_dropout = config.hidden_dropout\n\n        # Layernorm on the attention output\n        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                                      dtype=config.torch_dtype)\n\n        # MLP\n        self.mlp = MLP(config, device=device)\n\n    def forward(\n            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,\n    ):\n        # hidden_states: [s, b, h]\n\n        # Layer norm at the beginning of the transformer layer.\n        layernorm_output = 
self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, kv_cache = self.self_attention(\n            layernorm_output,\n            attention_mask,\n            rotary_pos_emb,\n            kv_cache=kv_cache,\n            use_cache=use_cache\n        )\n\n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)\n        layernorm_input = residual + layernorm_input\n\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        mlp_output = self.mlp(layernorm_output)\n\n        # Second residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)\n        output = residual + output\n\n        return output, kv_cache\n\n\nclass GLMTransformer(torch.nn.Module):\n    \"\"\"Transformer class.\"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        super(GLMTransformer, self).__init__()\n\n        self.fp32_residual_connection = config.fp32_residual_connection\n        self.post_layer_norm = config.post_layer_norm\n\n        # Number of layers.\n        self.num_layers = config.num_layers\n\n        # Transformer layers.\n        def build_layer(layer_number):\n            return GLMBlock(config, layer_number, device=device)\n\n        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])\n\n        if self.post_layer_norm:\n            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n            # Final layer norm before output.\n            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                                 dtype=config.torch_dtype)\n\n        self.gradient_checkpointing = False\n\n    def _get_layer(self, layer_number):\n        return self.layers[layer_number]\n\n    def forward(\n            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,\n            use_cache: Optional[bool] = True,\n            output_hidden_states: Optional[bool] = False,\n    ):\n        if not kv_caches:\n            kv_caches = [None for _ in range(self.num_layers)]\n        presents = () if use_cache else None\n        all_self_attentions = None\n        all_hidden_states = () if output_hidden_states else None\n        for index in range(self.num_layers):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer = self._get_layer(index)\n\n            if self.gradient_checkpointing and self.training:\n                hidden_states, kv_cache = torch.utils.checkpoint.checkpoint(\n                    layer,\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    kv_caches[index],\n                    use_cache\n                )\n            else:\n                hidden_states, kv_cache = layer(\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    
kv_cache=kv_caches[index],\n                    use_cache=use_cache\n                )\n\n            if use_cache:\n                presents = presents + (kv_cache,)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        # Final layer norm.\n        if self.post_layer_norm:\n            hidden_states = self.final_layernorm(hidden_states)\n\n        return hidden_states, presents, all_hidden_states, all_self_attentions\n\n\nclass ChatGLMPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    is_parallelizable = False\n    supports_gradient_checkpointing = True\n    config_class = ChatGLMConfig\n    base_model_prefix = \"transformer\"\n    _no_split_modules = [\"GLMBlock\"]\n\n    def _init_weights(self, module: nn.Module):\n        \"\"\"Initialize the weights.\"\"\"\n        return\n\n    def get_masks(self, input_ids, past_key_values, padding_mask=None):\n        batch_size, seq_length = input_ids.shape\n        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)\n        full_attention_mask.tril_()\n        past_length = 0\n        if past_key_values:\n            past_length = past_key_values[0][0].shape[0]\n        if past_length:\n            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,\n                                                        device=input_ids.device), full_attention_mask), dim=-1)\n        if padding_mask is not None:\n            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)\n        if not past_length and padding_mask is not None:\n            full_attention_mask -= padding_mask.unsqueeze(-1) - 1\n        full_attention_mask = (full_attention_mask < 0.5).bool()\n        full_attention_mask.unsqueeze_(1)\n        return full_attention_mask\n\n    def get_position_ids(self, input_ids, device):\n        batch_size, seq_length = input_ids.shape\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)\n        return position_ids\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, ChatGLMModel):\n            module.gradient_checkpointing = value\n            if hasattr(module, \"encoder\"):\n                encoder = getattr(module, \"encoder\")\n                encoder.gradient_checkpointing = value\n\n\nclass Embedding(torch.nn.Module):\n    \"\"\"Language model embeddings.\"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        super(Embedding, self).__init__()\n\n        self.hidden_size = config.hidden_size\n        # Word embeddings (parallel).\n        self.word_embeddings = nn.Embedding(\n            config.padded_vocab_size,\n            self.hidden_size,\n            dtype=config.torch_dtype,\n            device=device\n        )\n        self.fp32_residual_connection = config.fp32_residual_connection\n\n    def forward(self, input_ids):\n        # Embeddings.\n        words_embeddings = self.word_embeddings(input_ids)\n        embeddings = words_embeddings\n        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].\n        embeddings = embeddings.transpose(0, 1).contiguous()\n        # If the input flag for fp32 residual connection is set, convert for float.\n        if self.fp32_residual_connection:\n            
embeddings = embeddings.float()\n        return embeddings\n\n\nclass ChatGLMModel(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):\n        super().__init__(config)\n        if empty_init:\n            init_method = skip_init\n        else:\n            init_method = default_init\n        init_kwargs = {}\n        if device is not None:\n            init_kwargs[\"device\"] = device\n        self.embedding = init_method(Embedding, config, **init_kwargs)\n\n        # Rotary positional embeddings\n        self.seq_length = config.seq_length\n        rotary_dim = (\n            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels\n        )\n\n        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,\n                                              dtype=config.torch_dtype)\n        self.encoder = init_method(GLMTransformer, config, **init_kwargs)\n        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,\n                                        dtype=config.torch_dtype, **init_kwargs)\n        self.gradient_checkpointing = False\n\n    def get_input_embeddings(self):\n        return self.embedding\n\n    def forward(\n            self,\n            input_ids,\n            position_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            full_attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ):\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                # logger.warning_once(\n                #     \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                # )\n                use_cache = False\n\n        batch_size, seq_length = input_ids.shape\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embedding(input_ids)\n\n        if full_attention_mask is None:\n            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):\n                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)\n\n        # Rotary positional embeddings\n        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)\n        if position_ids is not None:\n            rotary_pos_emb = rotary_pos_emb[position_ids]\n        else:\n            rotary_pos_emb = rotary_pos_emb[None, :seq_length]\n        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()\n\n        # Run encoder.\n        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(\n            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,\n            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states\n        )\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n    def quantize(self, weight_bit_width: int):\n        from .quantization import quantize\n        quantize(self.encoder, weight_bit_width)\n        return self\n\n\nclass ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n        super().__init__(config)\n\n        self.max_sequence_length = config.max_length\n        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)\n        self.config = config\n        self.quantized = False\n\n        if self.config.quantization_bit:\n            self.quantize(self.config.quantization_bit, empty_init=True)\n\n    def _update_model_kwargs_for_generation(\n            self,\n            outputs: ModelOutput,\n            model_kwargs: Dict[str, Any],\n            is_encoder_decoder: bool = False,\n            standardize_cache_format: bool = False,\n    ) -> Dict[str, Any]:\n        # update past_key_values\n        model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n            outputs, standardize_cache_format=standardize_cache_format\n        )\n\n        # update attention mask\n        if \"attention_mask\" in model_kwargs:\n            attention_mask = model_kwargs[\"attention_mask\"]\n            model_kwargs[\"attention_mask\"] = torch.cat(\n                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n            )\n\n        # update position ids\n        if \"position_ids\" in model_kwargs:\n            position_ids = model_kwargs[\"position_ids\"]\n            new_position_id = position_ids[..., -1:].clone()\n            new_position_id += 1\n            model_kwargs[\"position_ids\"] = torch.cat(\n                [position_ids, new_position_id], dim=-1\n            )\n\n        model_kwargs[\"is_first_forward\"] = False\n        return model_kwargs\n\n    def forward(\n            self,\n            input_ids: Optional[torch.Tensor] = None,\n            
position_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            full_attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            labels: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n            return_last_logit: Optional[bool] = False,\n    ):\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            full_attention_mask=full_attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = transformer_outputs[0]\n        if return_last_logit:\n            hidden_states = hidden_states[-1:]\n        lm_logits = self.transformer.output_layer(hidden_states)\n        lm_logits = lm_logits.transpose(0, 1).contiguous()\n\n        loss = None\n        if labels is not None:\n            lm_logits = lm_logits.to(torch.float32)\n\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n            lm_logits = lm_logits.to(hidden_states.dtype)\n            loss = loss.to(hidden_states.dtype)\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n            self,\n            input_ids: torch.LongTensor,\n            past_key_values: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            is_first_forward: bool = True,\n            **kwargs\n    ) -> dict:\n        # only last token for input_ids if past is not None\n        if position_ids is None:\n            position_ids = self.get_position_ids(input_ids, device=input_ids.device)\n        if not is_first_forward:\n            position_ids = position_ids[..., -1:]\n            input_ids = input_ids[:, -1:]\n        return {\n            \"input_ids\": input_ids,\n            \"past_key_values\": past_key_values,\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"return_last_logit\": True\n        }\n\n    @staticmethod\n    def _reorder_cache(\n            
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:\n        \"\"\"\n        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or\n        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct\n        beam_idx at every generation step.\n\n        Output shares the same memory storage as `past`.\n        \"\"\"\n        # the kv cache here is laid out [seq_len, batch * beams, heads, head_dim],\n        # so beams are gathered along dim 1 rather than dim 0\n        return tuple(\n            (\n                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),\n                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),\n            )\n            for layer_past in past\n        )\n\n    def process_response(self, response):\n        response = response.strip()\n        response = response.replace(\"[[训练时间]]\", \"2023年\")\n        return response\n\n    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):\n        if history is None:\n            history = []\n        prompt = \"\"\n        for i, (old_query, response) in enumerate(history):\n            prompt += \"[Round {}]\\n\\n问：{}\\n\\n答：{}\\n\\n\".format(i + 1, old_query, response)\n        prompt += \"[Round {}]\\n\\n问：{}\\n\\n答：\".format(len(history) + 1, query)\n        inputs = tokenizer([prompt], return_tensors=\"pt\")\n        inputs = inputs.to(self.device)\n        return inputs\n\n    def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):\n        if history is None:\n            history = []\n        if history:\n            prompt = \"\\n\\n[Round {}]\\n\\n问：{}\\n\\n答：\".format(len(history) + 1, query)\n            input_ids = tokenizer.encode(prompt, add_special_tokens=False)\n            input_ids = input_ids[1:]\n            inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors=\"pt\", add_special_tokens=False)\n        else:\n            prompt = \"[Round {}]\\n\\n问：{}\\n\\n答：\".format(len(history) + 1, query)\n            inputs = tokenizer([prompt], return_tensors=\"pt\")\n        inputs = inputs.to(self.device)\n        return inputs\n\n    @torch.no_grad()\n    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,\n             do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n            logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        gen_kwargs = {\"max_length\": max_length, \"num_beams\": num_beams, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        inputs = self.build_inputs(tokenizer, query, history=history)\n        outputs = self.generate(**inputs, **gen_kwargs)\n        outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):]\n        response = tokenizer.decode(outputs)\n        response = self.process_response(response)\n        history = history + [(query, response)]\n        return response, history\n\n    @torch.no_grad()\n    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,\n                    max_length: int = 2048, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,\n                    return_past_key_values=False, **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n
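        # NOTE (editor): InvalidScoreLogitsProcessor, defined earlier in this\n        # file, rewrites any NaN/Inf score row into a valid distribution, so\n        # the sampling step can never fail on invalid logits.\n  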
         logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        gen_kwargs = {\"max_length\": max_length, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        if past_key_values is None and not return_past_key_values:\n            inputs = self.build_inputs(tokenizer, query, history=history)\n        else:\n            inputs = self.build_stream_inputs(tokenizer, query, history=history)\n        if past_key_values is not None:\n            past_length = past_key_values[0][0].shape[0]\n            inputs.position_ids += past_length\n            attention_mask = inputs.attention_mask\n            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)\n            inputs['attention_mask'] = attention_mask\n        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,\n                                            return_past_key_values=return_past_key_values, **gen_kwargs):\n            if return_past_key_values:\n                outputs, past_key_values = outputs\n            outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):]\n            response = tokenizer.decode(outputs)\n            response = self.process_response(response)\n            new_history = history + [(query, response)]\n            if return_past_key_values:\n                yield response, new_history, past_key_values\n            else:\n                yield response, new_history\n\n    @torch.no_grad()\n    def stream_generate(\n            self,\n            input_ids,\n            generation_config: Optional[GenerationConfig] = None,\n            logits_processor: Optional[LogitsProcessorList] = None,\n            stopping_criteria: Optional[StoppingCriteriaList] = None,\n            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n            return_past_key_values=False,\n            **kwargs,\n    ):\n        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]\n\n        if generation_config is None:\n            generation_config = self.generation_config\n        generation_config = copy.deepcopy(generation_config)\n        model_kwargs = generation_config.update(**kwargs)\n        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id\n\n        if isinstance(eos_token_id, int):\n            eos_token_id = [eos_token_id]\n\n        has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n        if has_default_max_length and generation_config.max_new_tokens is None:\n            warnings.warn(\n                f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. 
\"\n                \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n                \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n                UserWarning,\n            )\n        elif generation_config.max_new_tokens is not None:\n            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n            if not has_default_max_length:\n                logger.warn(\n                    f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n                    f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n                    \"Please refer to the documentation for more information. \"\n                    \"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\",\n                    UserWarning,\n                )\n\n        if input_ids_seq_length >= generation_config.max_length:\n            input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n            logger.warning(\n                f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n                f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n                \" increasing `max_new_tokens`.\"\n            )\n\n        # 2. Set generation parameters if not already defined\n        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n\n        logits_processor = self._get_logits_processor(\n            generation_config=generation_config,\n            input_ids_seq_length=input_ids_seq_length,\n            encoder_input_ids=input_ids,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            logits_processor=logits_processor,\n        )\n\n        stopping_criteria = self._get_stopping_criteria(\n            generation_config=generation_config, stopping_criteria=stopping_criteria\n        )\n        logits_warper = self._get_logits_warper(generation_config)\n\n        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)\n        scores = None\n        while True:\n            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n            # forward pass to get next token\n            outputs = self(\n                **model_inputs,\n                return_dict=True,\n                output_attentions=False,\n                output_hidden_states=False,\n            )\n\n            next_token_logits = outputs.logits[:, -1, :]\n\n            # pre-process distribution\n            next_token_scores = logits_processor(input_ids, next_token_logits)\n            next_token_scores = logits_warper(input_ids, next_token_scores)\n\n            # sample\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            if generation_config.do_sample:\n                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n            else:\n                next_tokens = torch.argmax(probs, dim=-1)\n\n            # update generated ids, model inputs, and length for next step\n            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n            model_kwargs = self._update_model_kwargs_for_generation(\n                
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n            )\n            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())\n            if return_past_key_values:\n                yield input_ids, outputs.past_key_values\n            else:\n                yield input_ids\n            # stop when each sentence is finished, or if we exceed the maximum length\n            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):\n                break\n\n    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):\n        if bits == 0:\n            return\n\n        from .quantization import quantize\n\n        if self.quantized:\n            logger.info(\"Already quantized.\")\n            return self\n\n        self.quantized = True\n\n        self.config.quantization_bit = bits\n\n        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,\n                                            **kwargs)\n        return self\n"
  },
  {
    "path": "src/resources/models/chatglm2-6B/quantization.py",
    "content": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctypes\nfrom transformers.utils import logging\n\nfrom typing import List\nfrom functools import partial\n\nlogger = logging.get_logger(__name__)\n\ntry:\n    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up\n\n    class Kernel:\n        def __init__(self, code: bytes, function_names: List[str]):\n            self.code = code\n            self._function_names = function_names\n            self._cmodule = LazyKernelCModule(self.code)\n\n            for name in self._function_names:\n                setattr(self, name, KernelFunction(self._cmodule, name))\n\n    quantization_code = \"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZ
WYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAe
PUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ\"\n\n    kernels = Kernel(\n        bz2.decompress(base64.b64decode(quantization_code)),\n        [\n            \"int4WeightCompression\",\n            \"int4WeightExtractionFloat\",\n            \"int4WeightExtractionHalf\",\n            \"int8WeightExtractionFloat\",\n            \"int8WeightExtractionHalf\",\n        ],\n    )\nexcept Exception as exception:\n    kernels = None\n    logger.warning(\"Failed to load cpm_kernels:\" + str(exception))\n\n\nclass W8A16Linear(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):\n        ctx.inp_shape = inp.size()\n        ctx.weight_bit_width = weight_bit_width\n        out_features = quant_w.size(0)\n        inp = inp.contiguous().view(-1, inp.size(-1))\n        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)\n        ctx.weight_shape = weight.size()\n        output = inp.mm(weight.t())\n        ctx.save_for_backward(inp, quant_w, scale_w)\n        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor):\n        inp, quant_w, scale_w = ctx.saved_tensors\n        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)\n        grad_output = grad_output.contiguous().view(-1, weight.size(0))\n        grad_input = grad_output.mm(weight)\n        grad_weight = grad_output.t().mm(inp)\n        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None\n\n\ndef compress_int4_weight(weight: torch.Tensor):  # (n, m)\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        assert m % 2 == 0\n        m = m // 2\n        out = torch.empty(n, m, dtype=torch.int8, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        kernels.int4WeightCompression(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],\n        )\n        return out\n\n\ndef extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):\n    assert scale_list.dtype in [torch.half, torch.bfloat16]\n    assert weight.dtype in [torch.int8]\n    if source_bit_width == 8:\n        return weight.to(scale_list.dtype) * scale_list[:, None]\n    elif 
source_bit_width == 4:\n        # only the fp16 int4 extraction kernel is registered in `kernels` above;\n        # no BFloat16 variant ships with this blob, so referencing\n        # `kernels.int4WeightExtractionBFloat16` would raise AttributeError\n        assert scale_list.dtype == torch.half, \"int4 weight extraction requires fp16 scales\"\n        func = kernels.int4WeightExtractionHalf\n    else:\n        assert False, \"Unsupported bit-width\"\n\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        func(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [\n                ctypes.c_void_p(weight.data_ptr()),\n                ctypes.c_void_p(scale_list.data_ptr()),\n                ctypes.c_void_p(out.data_ptr()),\n                ctypes.c_int32(n),\n                ctypes.c_int32(m),\n            ],\n        )\n        return out\n\n\nclass QuantizedLinear(torch.nn.Module):\n    def __init__(self, weight_bit_width: int, weight, bias=None, device=\"cpu\", dtype=None, empty_init=False, *args,\n                 **kwargs):\n        super().__init__()\n        self.weight_bit_width = weight_bit_width\n\n        # NOTE: `weight` must not be None even with `empty_init`, since its\n        # shape is needed below to size the int8 buffers\n        shape = weight.shape\n\n        if weight is None or empty_init:\n            self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)\n            self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)\n        else:\n            # per-output-channel symmetric quantization: scale each row by its\n            # absolute maximum, then round to signed int8\n            self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)\n            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)\n            if weight_bit_width == 4:\n                self.weight = compress_int4_weight(self.weight)\n\n        self.weight = Parameter(self.weight.to(device), requires_grad=False)\n        self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)\n        self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None\n\n    def forward(self, input):\n        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)\n        if self.bias is not None:\n            output = output + self.bias\n        return output\n\n\ndef quantize(model, weight_bit_width, empty_init=False, device=None):\n    \"\"\"Replace fp16 linear with quantized linear\"\"\"\n    # the four Linear submodules of every transformer block are swapped in\n    # place; weights are first moved to the current CUDA device because the\n    # int4 compression kernel runs on GPU\n    for layer in model.layers:\n        layer.self_attention.query_key_value = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),\n            bias=layer.self_attention.query_key_value.bias,\n            dtype=layer.self_attention.query_key_value.weight.dtype,\n            device=layer.self_attention.query_key_value.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.self_attention.dense = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),\n            bias=layer.self_attention.dense.bias,\n            dtype=layer.self_attention.dense.weight.dtype,\n            device=layer.self_attention.dense.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.mlp.dense_h_to_4h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),\n            bias=layer.mlp.dense_h_to_4h.bias,\n            dtype=layer.mlp.dense_h_to_4h.weight.dtype,\n            device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.mlp.dense_4h_to_h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),\n            bias=layer.mlp.dense_4h_to_h.bias,\n            dtype=layer.mlp.dense_4h_to_h.weight.dtype,\n            device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n\n    return model\n"
  },
  {
    "path": "src/resources/models/chatglm2-6B/tokenization_chatglm.py",
    "content": "import os\nimport torch\nfrom typing import List, Optional, Union, Dict\nfrom sentencepiece import SentencePieceProcessor\nfrom transformers import PreTrainedTokenizer\nfrom transformers.utils import logging, PaddingStrategy\nfrom transformers.tokenization_utils_base import EncodedInput, BatchEncoding\n\n\nclass SPTokenizer:\n    def __init__(self, model_path: str):\n        # reload tokenizer\n        assert os.path.isfile(model_path), model_path\n        self.sp_model = SentencePieceProcessor(model_file=model_path)\n\n        # BOS / EOS token IDs\n        self.n_words: int = self.sp_model.vocab_size()\n        self.bos_id: int = self.sp_model.bos_id()\n        self.eos_id: int = self.sp_model.eos_id()\n        self.pad_id: int = self.sp_model.eos_id()\n        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()\n\n        special_tokens = [\"[MASK]\", \"[gMASK]\", \"[sMASK]\", \"sop\", \"eop\"]\n        self.special_tokens = {}\n        self.index_special_tokens = {}\n        for token in special_tokens:\n            self.special_tokens[token] = self.n_words\n            self.index_special_tokens[self.n_words] = token\n            self.n_words += 1\n\n    def tokenize(self, s: str):\n        return self.sp_model.EncodeAsPieces(s)\n\n    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:\n        assert type(s) is str\n        t = self.sp_model.encode(s)\n        if bos:\n            t = [self.bos_id] + t\n        if eos:\n            t = t + [self.eos_id]\n        return t\n\n    def decode(self, t: List[int]) -> str:\n        return self.sp_model.decode(t)\n\n    def decode_tokens(self, tokens: List[str]) -> str:\n        text = self.sp_model.DecodePieces(tokens)\n        return text\n\n    def convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) in an id using the vocab. 
\"\"\"\n        if token in self.special_tokens:\n            return self.special_tokens[token]\n        return self.sp_model.PieceToId(token)\n\n    def convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        if index in self.index_special_tokens:\n            return \"\"\n        return self.sp_model.IdToPiece(index)\n\n\nclass ChatGLMTokenizer(PreTrainedTokenizer):\n    vocab_files_names = {\"vocab_file\": \"tokenizer.model\"}\n\n    model_input_names = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n\n    def __init__(self, vocab_file, padding_side=\"right\", **kwargs):\n        super().__init__(padding_side=padding_side, **kwargs)\n        self.name = \"GLMTokenizer\"\n\n        self.vocab_file = vocab_file\n        self.tokenizer = SPTokenizer(vocab_file)\n        self.special_tokens = {\n            \"<bos>\": self.tokenizer.bos_id,\n            \"<eos>\": self.tokenizer.eos_id,\n            \"<pad>\": self.tokenizer.pad_id\n        }\n\n    def get_command(self, token):\n        if token in self.special_tokens:\n            return self.special_tokens[token]\n        assert token in self.tokenizer.special_tokens, f\"{token} is not a special token for {self.name}\"\n        return self.tokenizer.special_tokens[token]\n\n    @property\n    def pad_token(self) -> str:\n        return \"</s>\"\n\n    @property\n    def pad_token_id(self):\n        return self.get_command(\"<pad>\")\n\n    @property\n    def vocab_size(self):\n        return self.tokenizer.n_words\n\n    def get_vocab(self):\n        \"\"\" Returns vocab as a dict \"\"\"\n        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text, **kwargs):\n        return self.tokenizer.tokenize(text)\n\n    def _convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) in an id using the vocab. 
\"\"\"\n        return self.tokenizer.convert_token_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.tokenizer.convert_id_to_token(index)\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for i, token in enumerate(tokens):\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special and i != 0:\n                    out_string += \" \"\n                out_string += self.tokenizer.decode_tokens(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.tokenizer.decode_tokens(current_sub_tokens)\n        return out_string\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the named of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"vocab_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)\n\n    def get_prefix_tokens(self):\n        prefix_tokens = [self.get_command(\"[gMASK]\"), self.get_command(\"sop\")]\n        return prefix_tokens\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A ChatGLM sequence has the following format:\n\n        - single sequence: `[gMASK] sop X`\n        - pair of sequences: `[gMASK] sop A B eop`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        prefix_tokens = self.get_prefix_tokens()\n        token_ids_0 = prefix_tokens + token_ids_0\n        if token_ids_1 is not None:\n            # token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command(\"<eos>\")]\n            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command(\"eop\")]\n        return token_ids_0\n\n    def _pad(\n            self,\n            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],\n            max_length: Optional[int] = None,\n            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n            pad_to_multiple_of: Optional[int] = None,\n            return_attention_mask: Optional[bool] = None,\n    ) -> dict:\n        \"\"\"\n        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)\n\n        Args:\n            encoded_inputs:\n                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).\n            max_length: maximum length of the returned list and optionally padding length (see below).\n                Will truncate by taking into account the special tokens.\n            padding_strategy: PaddingStrategy to use for padding.\n\n                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch\n                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)\n                - PaddingStrategy.DO_NOT_PAD: Do not pad\n                The tokenizer padding sides are defined in self.padding_side:\n\n                    - 'left': pads on the left of the sequences\n                    - 'right': pads on the right of the sequences\n            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.\n                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability\n                `>= 7.0` (Volta and newer).\n            return_attention_mask:\n                (optional) Set to False to avoid returning attention mask (default: set to model specifics)\n        \"\"\"\n        # Load from model defaults\n        # assert self.padding_side == \"left\"\n\n        required_input = encoded_inputs[self.model_input_names[0]]\n        seq_length = len(required_input)\n\n        if padding_strategy == PaddingStrategy.LONGEST:\n            max_length = len(required_input)\n\n        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n\n        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length\n\n        # Initialize attention mask if not present.\n        if \"attention_mask\" not in encoded_inputs:\n            encoded_inputs[\"attention_mask\"] = [1] * seq_length\n\n        if \"position_ids\" not in encoded_inputs:\n            encoded_inputs[\"position_ids\"] = list(range(seq_length))\n\n        
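# Worked example (editor's sketch, not upstream code): left-padding the\n        # 3-token input [a, b, c] to max_length=5 with this method yields\n        #     input_ids      -> [pad, pad, a, b, c]\n        #     attention_mask -> [0, 0, 1, 1, 1]\n        #     position_ids   -> [0, 0, 0, 1, 2]\n        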
if needs_to_be_padded:\n            difference = max_length - len(required_input)\n\n            if self.padding_side == \"left\":\n                if \"attention_mask\" in encoded_inputs:\n                    encoded_inputs[\"attention_mask\"] = [0] * difference + encoded_inputs[\"attention_mask\"]\n                if \"position_ids\" in encoded_inputs:\n                    encoded_inputs[\"position_ids\"] = [0] * difference + encoded_inputs[\"position_ids\"]\n                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input\n            else:\n                if \"attention_mask\" in encoded_inputs:\n                    encoded_inputs[\"attention_mask\"] = encoded_inputs[\"attention_mask\"] + [0] * difference\n                if \"position_ids\" in encoded_inputs:\n                    encoded_inputs[\"position_ids\"] = encoded_inputs[\"position_ids\"] + [0] * difference\n                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference\n\n        return encoded_inputs\n"
  },
  {
    "path": "src/resources/models/chatglm2-6B/tokenizer_config.json",
    "content": "{\n  \"name_or_path\": \"THUDM/chatglm-6b\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"ChatGLMTokenizer\",\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_chatglm.ChatGLMTokenizer\",\n      null\n      ]\n  }\n}\n"
  },
  {
    "path": "src/resources/models/chatglm3-6B/config.json",
    "content": "{\n  \"_name_or_path\": \"THUDM/chatglm3-6b\",\n  \"model_type\": \"chatglm\",\n  \"architectures\": [\n    \"ChatGLMModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_chatglm.ChatGLMConfig\",\n    \"AutoModel\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\n    \"AutoModelForCausalLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\n    \"AutoModelForSeq2SeqLM\": \"modeling_chatglm.ChatGLMForConditionalGeneration\",\n    \"AutoModelForSequenceClassification\": \"modeling_chatglm.ChatGLMForSequenceClassification\"\n  },\n  \"add_bias_linear\": false,\n  \"add_qkv_bias\": true,\n  \"apply_query_key_layer_scaling\": true,\n  \"apply_residual_connection_post_layernorm\": false,\n  \"attention_dropout\": 0.0,\n  \"attention_softmax_in_fp32\": true,\n  \"bias_dropout_fusion\": true,\n  \"ffn_hidden_size\": 13696,\n  \"fp32_residual_connection\": false,\n  \"hidden_dropout\": 0.0,\n  \"hidden_size\": 4096,\n  \"kv_channels\": 128,\n  \"layernorm_epsilon\": 1e-05,\n  \"multi_query_attention\": true,\n  \"multi_query_group_num\": 2,\n  \"num_attention_heads\": 32,\n  \"num_layers\": 28,\n  \"original_rope\": true,\n  \"padded_vocab_size\": 65024,\n  \"post_layer_norm\": true,\n  \"rmsnorm\": true,\n  \"seq_length\": 8192,\n  \"use_cache\": true,\n  \"torch_dtype\": \"float16\",\n  \"transformers_version\": \"4.30.2\",\n  \"tie_word_embeddings\": false,\n  \"eos_token_id\": 2,\n  \"pad_token_id\": -100\n}"
  },
  {
    "path": "src/resources/models/chatglm3-6B/configuration_chatglm.py",
    "content": "from transformers import PretrainedConfig\n\n\nclass ChatGLMConfig(PretrainedConfig):\n    model_type = \"chatglm\"\n    def __init__(\n        self,\n        num_layers=28,\n        padded_vocab_size=65024,\n        hidden_size=4096,\n        ffn_hidden_size=13696,\n        kv_channels=128,\n        num_attention_heads=32,\n        seq_length=2048,\n        hidden_dropout=0.0,\n        classifier_dropout=None,\n        attention_dropout=0.0,\n        layernorm_epsilon=1e-5,\n        rmsnorm=True,\n        apply_residual_connection_post_layernorm=False,\n        post_layer_norm=True,\n        add_bias_linear=False,\n        add_qkv_bias=False,\n        bias_dropout_fusion=True,\n        multi_query_attention=False,\n        multi_query_group_num=1,\n        apply_query_key_layer_scaling=True,\n        attention_softmax_in_fp32=True,\n        fp32_residual_connection=False,\n        quantization_bit=0,\n        pre_seq_len=None,\n        prefix_projection=False,\n        **kwargs\n    ):\n        self.num_layers = num_layers\n        self.vocab_size = padded_vocab_size\n        self.padded_vocab_size = padded_vocab_size\n        self.hidden_size = hidden_size\n        self.ffn_hidden_size = ffn_hidden_size\n        self.kv_channels = kv_channels\n        self.num_attention_heads = num_attention_heads\n        self.seq_length = seq_length\n        self.hidden_dropout = hidden_dropout\n        self.classifier_dropout = classifier_dropout\n        self.attention_dropout = attention_dropout\n        self.layernorm_epsilon = layernorm_epsilon\n        self.rmsnorm = rmsnorm\n        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm\n        self.post_layer_norm = post_layer_norm\n        self.add_bias_linear = add_bias_linear\n        self.add_qkv_bias = add_qkv_bias\n        self.bias_dropout_fusion = bias_dropout_fusion\n        self.multi_query_attention = multi_query_attention\n        self.multi_query_group_num = multi_query_group_num\n        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = attention_softmax_in_fp32\n        self.fp32_residual_connection = fp32_residual_connection\n        self.quantization_bit = quantization_bit\n        self.pre_seq_len = pre_seq_len\n        self.prefix_projection = prefix_projection\n        super().__init__(**kwargs)"
  },
  {
    "path": "src/resources/models/chatglm3-6B/modeling_chatglm.py",
    "content": "\"\"\" PyTorch ChatGLM model. \"\"\"\n\nimport math\nimport copy\nimport warnings\nimport re\nimport sys\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss\nfrom torch.nn.utils import skip_init\nfrom typing import Optional, Tuple, Union, List, Callable, Dict, Any\nfrom copy import deepcopy\n\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\nfrom transformers.generation.logits_process import LogitsProcessor\nfrom transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput\n\nfrom .configuration_chatglm import ChatGLMConfig\n\n# flags required to enable jit fusion kernels\n\nif sys.platform != 'darwin':\n    torch._C._jit_set_profiling_mode(False)\n    torch._C._jit_set_profiling_executor(False)\n    torch._C._jit_override_can_fuse_on_cpu(True)\n    torch._C._jit_override_can_fuse_on_gpu(True)\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"THUDM/ChatGLM\"\n_CONFIG_FOR_DOC = \"ChatGLMConfig\"\n\nCHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"THUDM/chatglm3-6b\",\n    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm\n]\n\n\ndef default_init(cls, *args, **kwargs):\n    return cls(*args, **kwargs)\n\n\nclass InvalidScoreLogitsProcessor(LogitsProcessor):\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n        if torch.isnan(scores).any() or torch.isinf(scores).any():\n            scores.zero_()\n            scores[..., 5] = 5e4\n        return scores\n\n\nclass PrefixEncoder(torch.nn.Module):\n    \"\"\"\n    The torch.nn model to encode the prefix\n    Input shape: (batch-size, prefix-length)\n    Output shape: (batch-size, prefix-length, 2*layers*hidden)\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig):\n        super().__init__()\n        self.prefix_projection = config.prefix_projection\n        if self.prefix_projection:\n            # Use a two-layer MLP to encode the prefix\n            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2\n            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)\n            self.trans = torch.nn.Sequential(\n                torch.nn.Linear(kv_size, config.hidden_size),\n                torch.nn.Tanh(),\n                torch.nn.Linear(config.hidden_size, kv_size)\n            )\n        else:\n            self.embedding = torch.nn.Embedding(config.pre_seq_len,\n                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)\n\n    def forward(self, prefix: torch.Tensor):\n        if self.prefix_projection:\n            prefix_tokens = self.embedding(prefix)\n            past_key_values = self.trans(prefix_tokens)\n        else:\n            past_key_values = self.embedding(prefix)\n        return past_key_values\n\n\ndef split_tensor_along_last_dim(\n        tensor: torch.Tensor,\n        num_partitions: int,\n        contiguous_split_chunks: bool = False,\n) -> List[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n     
   contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = tensor.size()[last_dim] // num_partitions\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # Note: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, original_impl=False, device=None, dtype=None):\n        super().__init__()\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n        self.dim = dim\n        self.original_impl = original_impl\n\n    def forward_impl(\n            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000\n    ):\n        \"\"\"Enhanced Transformer with Rotary Position Embedding.\n\n        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/\n        transformers/rope/__init__.py. MIT License:\n        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.\n        \"\"\"\n        # $\\Theta = {\\theta_i = 10000^{\\frac{2(i-1)}{d}}, i \\in [1, 2, ..., \\frac{d}{2}]}$\n        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))\n\n        # Create position indexes `[0, 1, ..., seq_len - 1]`\n        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)\n\n        # Calculate the product of position index and $\\theta_i$\n        idx_theta = torch.outer(seq_idx, theta).float()\n\n        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)\n\n        # this is to mimic the behaviour of complex32, else we will get different results\n        if dtype in (torch.float16, torch.bfloat16, torch.int8):\n            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()\n        return cache\n\n    def forward(self, max_seq_len, offset=0):\n        return self.forward_impl(\n            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device\n        )\n\n\n@torch.jit.script\ndef apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:\n    # x: [sq, b, np, hn]\n    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)\n    rot_dim = rope_cache.shape[-2] * 2\n    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]\n    # truncate to support variable sizes\n    rope_cache = rope_cache[:sq]\n    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)\n    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)\n    x_out2 = torch.stack(\n        [\n            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],\n            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],\n        ],\n        -1,\n    )\n    x_out2 = x_out2.flatten(3)\n    return torch.cat((x_out2, x_pass), dim=-1)\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))\n        self.eps = eps\n\n    def 
forward(self, hidden_states: torch.Tensor):\n        input_dtype = hidden_states.dtype\n        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)\n\n        return (self.weight * hidden_states).to(input_dtype)\n\n\nclass CoreAttention(torch.nn.Module):\n    def __init__(self, config: ChatGLMConfig, layer_number):\n        super(CoreAttention, self).__init__()\n\n        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32\n        if self.apply_query_key_layer_scaling:\n            self.attention_softmax_in_fp32 = True\n        self.layer_number = max(1, layer_number)\n\n        projection_size = config.kv_channels * config.num_attention_heads\n\n        # Per attention head and per partition values.\n        self.hidden_size_per_partition = projection_size\n        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads\n        self.num_attention_heads_per_partition = config.num_attention_heads\n\n        coeff = None\n        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)\n        if self.apply_query_key_layer_scaling:\n            coeff = self.layer_number\n            self.norm_factor *= coeff\n        self.coeff = coeff\n\n        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)\n\n    def forward(self, query_layer, key_layer, value_layer, attention_mask):\n        pytorch_major_version = int(torch.__version__.split('.')[0])\n        if pytorch_major_version >= 2:\n            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]\n            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:\n                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n                                                                                 is_causal=True)\n            else:\n                if attention_mask is not None:\n                    attention_mask = ~attention_mask\n                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n                                                                                 attention_mask)\n            context_layer = context_layer.permute(2, 0, 1, 3)\n            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n            context_layer = context_layer.reshape(*new_context_layer_shape)\n        else:\n            # Raw attention scores\n\n            # [b, np, sq, sk]\n            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))\n\n            # [sq, b, np, hn] -> [sq, b * np, hn]\n            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)\n            # [sk, b, np, hn] -> [sk, b * np, hn]\n            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)\n\n            # preallocting input tensor: [b * np, sq, sk]\n            matmul_input_buffer = torch.empty(\n                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,\n                device=query_layer.device\n            )\n\n            # Raw attention scores. 
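\n            # NOTE (editor): torch.baddbmm computes beta * input + alpha * (batch1 @ batch2);\n            # with beta=0.0 the uninitialized matmul_input_buffer only supplies the\n            # output shape, dtype and device -- its values are ignored.\n            # shape: 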
[b * np, sq, sk]\n            matmul_result = torch.baddbmm(\n                matmul_input_buffer,\n                query_layer.transpose(0, 1),  # [b * np, sq, hn]\n                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]\n                beta=0.0,\n                alpha=(1.0 / self.norm_factor),\n            )\n\n            # change view to [b, np, sq, sk]\n            attention_scores = matmul_result.view(*output_size)\n\n            # ===========================\n            # Attention probs and dropout\n            # ===========================\n\n            # attention scores and attention mask [b, np, sq, sk]\n            if self.attention_softmax_in_fp32:\n                attention_scores = attention_scores.float()\n            if self.coeff is not None:\n                attention_scores = attention_scores * self.coeff\n            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:\n                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],\n                                            device=attention_scores.device, dtype=torch.bool)\n                attention_mask.tril_()\n                attention_mask = ~attention_mask\n            if attention_mask is not None:\n                attention_scores = attention_scores.masked_fill(attention_mask, float(\"-inf\"))\n            attention_probs = F.softmax(attention_scores, dim=-1)\n            attention_probs = attention_probs.type_as(value_layer)\n\n            # This is actually dropping out entire tokens to attend to, which might\n            # seem a bit unusual, but is taken from the original Transformer paper.\n            attention_probs = self.attention_dropout(attention_probs)\n            # =========================\n            # Context layer. 
[sq, b, hp]\n            # =========================\n\n            # value_layer -> context layer.\n            # [sk, b, np, hn] --> [b, np, sq, hn]\n\n            # context layer shape: [b, np, sq, hn]\n            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))\n            # change view [sk, b * np, hn]\n            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)\n            # change view [b * np, sq, sk]\n            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n            # matmul: [b * np, sq, hn]\n            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))\n            # change view [b, np, sq, hn]\n            context_layer = context_layer.view(*output_size)\n            # [b, np, sq, hn] --> [sq, b, np, hn]\n            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()\n            # [sq, b, np, hn] --> [sq, b, hp]\n            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n            context_layer = context_layer.view(*new_context_layer_shape)\n\n        return context_layer\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"Parallel self-attention layer abstract class.\n\n    Self-attention layer takes input with size [s, b, h]\n    and returns output of the same size.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n        super(SelfAttention, self).__init__()\n        self.layer_number = max(1, layer_number)\n\n        self.projection_size = config.kv_channels * config.num_attention_heads\n\n        # Per attention head and per partition values.\n        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads\n        self.num_attention_heads_per_partition = config.num_attention_heads\n\n        self.multi_query_attention = config.multi_query_attention\n        self.qkv_hidden_size = 3 * self.projection_size\n        if self.multi_query_attention:\n            self.num_multi_query_groups_per_partition = config.multi_query_group_num\n            self.qkv_hidden_size = (\n                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num\n            )\n        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,\n                                         bias=config.add_bias_linear or config.add_qkv_bias,\n                                         device=device, **_config_to_kwargs(config)\n                                         )\n\n        self.core_attention = CoreAttention(config, self.layer_number)\n\n        # Output.\n        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,\n                               device=device, **_config_to_kwargs(config)\n                               )\n\n    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):\n        if self.multi_query_attention:\n            num_attention_heads = self.num_multi_query_groups_per_partition\n        else:\n            num_attention_heads = self.num_attention_heads_per_partition\n        return torch.empty(\n            inference_max_sequence_len,\n            batch_size,\n            num_attention_heads,\n            self.hidden_size_per_attention_head,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(\n            self, 
hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True\n    ):\n        # hidden_states: [sq, b, h]\n\n        # =================================================\n        # Pre-allocate memory for key-values for inference.\n        # =================================================\n        # =====================\n        # Query, Key, and Value\n        # =====================\n\n        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]\n        mixed_x_layer = self.query_key_value(hidden_states)\n\n        if self.multi_query_attention:\n            (query_layer, key_layer, value_layer) = mixed_x_layer.split(\n                [\n                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,\n                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n                ],\n                dim=-1,\n            )\n            query_layer = query_layer.view(\n                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n            )\n            key_layer = key_layer.view(\n                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n            )\n            value_layer = value_layer.view(\n                value_layer.size()[:-1]\n                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n            )\n        else:\n            new_tensor_shape = mixed_x_layer.size()[:-1] + \\\n                               (self.num_attention_heads_per_partition,\n                                3 * self.hidden_size_per_attention_head)\n            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)\n\n            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]\n            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n\n        # apply relative positional encoding (rotary embedding)\n        if rotary_pos_emb is not None:\n            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)\n            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)\n\n        # adjust key and value for inference\n        if kv_cache is not None:\n            cache_k, cache_v = kv_cache\n            key_layer = torch.cat((cache_k, key_layer), dim=0)\n            value_layer = torch.cat((cache_v, value_layer), dim=0)\n        if use_cache:\n            kv_cache = (key_layer, value_layer)\n        else:\n            kv_cache = None\n\n        if self.multi_query_attention:\n            key_layer = key_layer.unsqueeze(-2)\n            key_layer = key_layer.expand(\n                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1\n            )\n            key_layer = key_layer.contiguous().view(\n                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n            )\n            value_layer = value_layer.unsqueeze(-2)\n            value_layer = value_layer.expand(\n                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1\n            )\n            value_layer = value_layer.contiguous().view(\n                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n            )\n\n    
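    # NOTE: with multi_query_attention, key/value carry only num_multi_query_groups_per_partition heads;\n        # the expand + view above repeat each key/value group so the key/value head count\n        # matches the number of query heads before entering core attention.\n\n    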
    # ==================================\n        # core attention computation\n        # ==================================\n\n        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)\n\n        # =================\n        # Output. [sq, b, h]\n        # =================\n\n        output = self.dense(context_layer)\n\n        return output, kv_cache\n\n\ndef _config_to_kwargs(args):\n    common_kwargs = {\n        \"dtype\": args.torch_dtype,\n    }\n    return common_kwargs\n\n\nclass MLP(torch.nn.Module):\n    \"\"\"MLP.\n\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, perform nonlinear transformation, and project the\n    state back into h hidden dimension.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        super(MLP, self).__init__()\n\n        self.add_bias = config.add_bias_linear\n\n        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf\n        self.dense_h_to_4h = nn.Linear(\n            config.hidden_size,\n            config.ffn_hidden_size * 2,\n            bias=self.add_bias,\n            device=device,\n            **_config_to_kwargs(config)\n        )\n\n        def swiglu(x):\n            x = torch.chunk(x, 2, dim=-1)\n            return F.silu(x[0]) * x[1]\n\n        self.activation_func = swiglu\n\n        # Project back to h.\n        self.dense_4h_to_h = nn.Linear(\n            config.ffn_hidden_size,\n            config.hidden_size,\n            bias=self.add_bias,\n            device=device,\n            **_config_to_kwargs(config)\n        )\n\n    def forward(self, hidden_states):\n        # [s, b, 4hp]\n        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n        intermediate_parallel = self.activation_func(intermediate_parallel)\n        # [s, b, h]\n        output = self.dense_4h_to_h(intermediate_parallel)\n        return output\n\n\nclass GLMBlock(torch.nn.Module):\n    \"\"\"A single transformer layer.\n\n    Transformer layer takes input with size [s, b, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n        super(GLMBlock, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm\n\n        self.fp32_residual_connection = config.fp32_residual_connection\n\n        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                             dtype=config.torch_dtype)\n\n        # Self attention.\n        self.self_attention = SelfAttention(config, layer_number, device=device)\n        self.hidden_dropout = config.hidden_dropout\n\n        # Layernorm on the attention output\n        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                                      dtype=config.torch_dtype)\n\n        # MLP\n        self.mlp = MLP(config, device=device)\n\n    def forward(\n            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,\n    ):\n        # hidden_states: [s, b, h]\n\n        # Layer norm at the beginning of the transformer layer.\n        layernorm_output = 
self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, kv_cache = self.self_attention(\n            layernorm_output,\n            attention_mask,\n            rotary_pos_emb,\n            kv_cache=kv_cache,\n            use_cache=use_cache\n        )\n\n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)\n        layernorm_input = residual + layernorm_input\n\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        mlp_output = self.mlp(layernorm_output)\n\n        # Second residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)\n        output = residual + output\n\n        return output, kv_cache\n\n\nclass GLMTransformer(torch.nn.Module):\n    \"\"\"Transformer class.\"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        super(GLMTransformer, self).__init__()\n\n        self.fp32_residual_connection = config.fp32_residual_connection\n        self.post_layer_norm = config.post_layer_norm\n\n        # Number of layers.\n        self.num_layers = config.num_layers\n\n        # Transformer layers.\n        def build_layer(layer_number):\n            return GLMBlock(config, layer_number, device=device)\n\n        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])\n\n        if self.post_layer_norm:\n            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n            # Final layer norm before output.\n            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n                                                 dtype=config.torch_dtype)\n\n        self.gradient_checkpointing = False\n\n    def _get_layer(self, layer_number):\n        return self.layers[layer_number]\n\n    def forward(\n            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,\n            use_cache: Optional[bool] = True,\n            output_hidden_states: Optional[bool] = False,\n    ):\n        if not kv_caches:\n            kv_caches = [None for _ in range(self.num_layers)]\n        presents = () if use_cache else None\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        all_self_attentions = None\n        all_hidden_states = () if output_hidden_states else None\n        for index in range(self.num_layers):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer = self._get_layer(index)\n            if self.gradient_checkpointing and self.training:\n                layer_ret = torch.utils.checkpoint.checkpoint(\n                    layer,\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    kv_caches[index],\n                    use_cache\n                )\n            else:\n                layer_ret = layer(\n                    hidden_states,\n                    attention_mask,\n                    rotary_pos_emb,\n                    kv_cache=kv_caches[index],\n                    use_cache=use_cache\n                )\n            hidden_states, kv_cache = layer_ret\n            if use_cache:\n                presents = presents + (kv_cache,)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        # Final layer norm.\n        if self.post_layer_norm:\n            hidden_states = self.final_layernorm(hidden_states)\n\n        return hidden_states, presents, all_hidden_states, all_self_attentions\n\n\nclass ChatGLMPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    is_parallelizable = False\n    supports_gradient_checkpointing = True\n    config_class = ChatGLMConfig\n    base_model_prefix = \"transformer\"\n    _no_split_modules = [\"GLMBlock\"]\n\n    def _init_weights(self, module: nn.Module):\n        \"\"\"Initialize the weights.\"\"\"\n        return\n\n    def get_masks(self, input_ids, past_key_values, padding_mask=None):\n        batch_size, seq_length = input_ids.shape\n        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)\n        full_attention_mask.tril_()\n        past_length = 0\n        if past_key_values:\n            past_length = past_key_values[0][0].shape[0]\n        if past_length:\n            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,\n                                                        device=input_ids.device), full_attention_mask), dim=-1)\n        if padding_mask is not None:\n            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)\n        if not past_length and padding_mask is not None:\n            full_attention_mask -= padding_mask.unsqueeze(-1) - 1\n        full_attention_mask = (full_attention_mask < 0.5).bool()\n        full_attention_mask.unsqueeze_(1)\n        return full_attention_mask\n\n    def get_position_ids(self, input_ids, device):\n        batch_size, seq_length = input_ids.shape\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)\n        return position_ids\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GLMTransformer):\n            module.gradient_checkpointing = value\n\n\nclass Embedding(torch.nn.Module):\n    \"\"\"Language model embeddings.\"\"\"\n\n    def __init__(self, config: ChatGLMConfig, device=None):\n        
super(Embedding, self).__init__()\n\n        self.hidden_size = config.hidden_size\n        # Word embeddings (parallel).\n        self.word_embeddings = nn.Embedding(\n            config.padded_vocab_size,\n            self.hidden_size,\n            dtype=config.torch_dtype,\n            device=device\n        )\n        self.fp32_residual_connection = config.fp32_residual_connection\n\n    def forward(self, input_ids):\n        # Embeddings.\n        words_embeddings = self.word_embeddings(input_ids)\n        embeddings = words_embeddings\n        # Data format change to avoid explicit transposes: [b s h] --> [s b h].\n        embeddings = embeddings.transpose(0, 1).contiguous()\n        # If the input flag for fp32 residual connection is set, convert to float.\n        if self.fp32_residual_connection:\n            embeddings = embeddings.float()\n        return embeddings\n\n\nclass ChatGLMModel(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):\n        super().__init__(config)\n        if empty_init:\n            init_method = skip_init\n        else:\n            init_method = default_init\n        init_kwargs = {}\n        if device is not None:\n            init_kwargs[\"device\"] = device\n        self.embedding = init_method(Embedding, config, **init_kwargs)\n        self.num_layers = config.num_layers\n        self.multi_query_group_num = config.multi_query_group_num\n        self.kv_channels = config.kv_channels\n\n        # Rotary positional embeddings\n        self.seq_length = config.seq_length\n        rotary_dim = (\n            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels\n        )\n\n        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,\n                                              dtype=config.torch_dtype)\n        self.encoder = init_method(GLMTransformer, config, **init_kwargs)\n        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,\n                                        dtype=config.torch_dtype, **init_kwargs)\n        self.pre_seq_len = config.pre_seq_len\n        self.prefix_projection = config.prefix_projection\n        if self.pre_seq_len is not None:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n            self.prefix_encoder = PrefixEncoder(config)\n            self.dropout = torch.nn.Dropout(0.1)\n\n    def get_input_embeddings(self):\n        return self.embedding.word_embeddings\n\n    def get_prompt(self, batch_size, device, dtype=torch.half):\n        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)\n        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)\n        past_key_values = past_key_values.view(\n            batch_size,\n            self.pre_seq_len,\n            self.num_layers * 2,\n            self.multi_query_group_num,\n            self.kv_channels\n        )\n        # seq_len, b, nh, hidden_size\n        past_key_values = self.dropout(past_key_values)\n        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)\n        return past_key_values\n\n    def forward(\n            self,\n            input_ids,\n            position_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.BoolTensor] = None,\n           
 full_attention_mask: Optional[torch.BoolTensor] = None,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ):\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        batch_size, seq_length = input_ids.shape\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embedding(input_ids)\n\n        if self.pre_seq_len is not None:\n            if past_key_values is None:\n                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,\n                                                  dtype=inputs_embeds.dtype)\n            if attention_mask is not None:\n                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),\n                                            attention_mask], dim=-1)\n\n        if full_attention_mask is None:\n            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):\n                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)\n\n        # Rotary positional embeddings\n        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)\n        if position_ids is not None:\n            rotary_pos_emb = rotary_pos_emb[position_ids]\n        else:\n            rotary_pos_emb = rotary_pos_emb[None, :seq_length]\n        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()\n\n        # Run encoder.\n        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(\n            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,\n            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states\n        )\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n    def quantize(self, weight_bit_width: int):\n        from .quantization import quantize\n        quantize(self.encoder, weight_bit_width)\n        return self\n\n\nclass ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n        super().__init__(config)\n\n        self.max_sequence_length = config.max_length\n        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)\n        self.config = config\n        self.quantized = False\n\n        if self.config.quantization_bit:\n            self.quantize(self.config.quantization_bit, empty_init=True)\n\n    def _update_model_kwargs_for_generation(\n            self,\n            outputs: ModelOutput,\n            model_kwargs: Dict[str, Any],\n            is_encoder_decoder: bool = False,\n            standardize_cache_format: bool = False,\n    
) -> Dict[str, Any]:\n        # update past_key_values\n        model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n            outputs, standardize_cache_format=standardize_cache_format\n        )\n\n        # update attention mask\n        if \"attention_mask\" in model_kwargs:\n            attention_mask = model_kwargs[\"attention_mask\"]\n            model_kwargs[\"attention_mask\"] = torch.cat(\n                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n            )\n\n        # update position ids\n        if \"position_ids\" in model_kwargs:\n            position_ids = model_kwargs[\"position_ids\"]\n            new_position_id = position_ids[..., -1:].clone()\n            new_position_id += 1\n            model_kwargs[\"position_ids\"] = torch.cat(\n                [position_ids, new_position_id], dim=-1\n            )\n\n        model_kwargs[\"is_first_forward\"] = False\n        return model_kwargs\n\n    def prepare_inputs_for_generation(\n            self,\n            input_ids: torch.LongTensor,\n            past_key_values: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            is_first_forward: bool = True,\n            **kwargs\n    ) -> dict:\n        # only last token for input_ids if past is not None\n        if position_ids is None:\n            position_ids = self.get_position_ids(input_ids, device=input_ids.device)\n        if not is_first_forward:\n            if past_key_values is not None:\n                position_ids = position_ids[..., -1:]\n                input_ids = input_ids[:, -1:]\n        return {\n            \"input_ids\": input_ids,\n            \"past_key_values\": past_key_values,\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"return_last_logit\": True,\n            \"use_cache\": use_cache\n        }\n\n    def forward(\n            self,\n            input_ids: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.Tensor] = None,\n            labels: Optional[torch.Tensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n            return_last_logit: Optional[bool] = False,\n    ):\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = transformer_outputs[0]\n        if return_last_logit:\n            hidden_states = hidden_states[-1:]\n        lm_logits = self.transformer.output_layer(hidden_states)\n        lm_logits = lm_logits.transpose(0, 1).contiguous()\n\n    
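    # hidden_states is [sq, b, h], so lm_logits comes out of the output layer as\n        # [sq, b, vocab]; the transpose above yields the conventional [b, sq, vocab]\n        # layout expected by the loss computation and by generation.\n\n    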
    loss = None\n        if labels is not None:\n            lm_logits = lm_logits.to(torch.float32)\n\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n            lm_logits = lm_logits.to(hidden_states.dtype)\n            loss = loss.to(hidden_states.dtype)\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(\n            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:\n        \"\"\"\n        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or\n        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct\n        beam_idx at every generation step.\n\n        Output shares the same memory storage as `past`.\n        \"\"\"\n        return tuple(\n            (\n                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),\n                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),\n            )\n            for layer_past in past\n        )\n\n    def process_response(self, output, history):\n        content = \"\"\n        history = deepcopy(history)\n        for response in output.split(\"<|assistant|>\"):\n            # a response with no newline has no metadata line; guard the split to avoid a ValueError\n            if \"\\n\" in response:\n                metadata, content = response.split(\"\\n\", maxsplit=1)\n            else:\n                metadata, content = \"\", response\n            if not metadata.strip():\n                content = content.strip()\n                history.append({\"role\": \"assistant\", \"metadata\": metadata, \"content\": content})\n                content = content.replace(\"[[训练时间]]\", \"2023年\")\n            else:\n                history.append({\"role\": \"assistant\", \"metadata\": metadata, \"content\": content})\n                if history[0][\"role\"] == \"system\" and \"tools\" in history[0]:\n                    content = \"\\n\".join(content.split(\"\\n\")[1:-1])\n                    def tool_call(**kwargs):\n                        return kwargs\n                    parameters = eval(content)\n                    content = {\"name\": metadata.strip(), \"parameters\": parameters}\n                else:\n                    content = {\"name\": metadata.strip(), \"content\": content}\n        return content, history\n\n    @torch.inference_mode()\n    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = \"user\",\n             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,\n             **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n            logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        gen_kwargs = {\"max_length\": max_length, 
\"num_beams\": num_beams, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        inputs = tokenizer.build_chat_input(query, history=history, role=role)\n        inputs = inputs.to(self.device)\n        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command(\"<|user|>\"),\n                        tokenizer.get_command(\"<|observation|>\")]\n        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n        outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n        response = tokenizer.decode(outputs)\n        history.append({\"role\": role, \"content\": query})\n        response, history = self.process_response(response, history)\n        return response, history\n\n    @torch.inference_mode()\n    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = \"user\",\n                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,\n                    logits_processor=None, return_past_key_values=False, **kwargs):\n        if history is None:\n            history = []\n        if logits_processor is None:\n            logits_processor = LogitsProcessorList()\n        logits_processor.append(InvalidScoreLogitsProcessor())\n        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command(\"<|user|>\"),\n                        tokenizer.get_command(\"<|observation|>\")]\n        gen_kwargs = {\"max_length\": max_length, \"do_sample\": do_sample, \"top_p\": top_p,\n                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n        if past_key_values is None:\n            inputs = tokenizer.build_chat_input(query, history=history, role=role)\n        else:\n            inputs = tokenizer.build_chat_input(query, role=role)\n        inputs = inputs.to(self.device)\n        if past_key_values is not None:\n            past_length = past_key_values[0][0].shape[0]\n            if self.transformer.pre_seq_len is not None:\n                past_length -= self.transformer.pre_seq_len\n            inputs.position_ids += past_length\n            attention_mask = inputs.attention_mask\n            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)\n            inputs['attention_mask'] = attention_mask\n        history.append({\"role\": role, \"content\": query})\n        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,\n                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,\n                                            **gen_kwargs):\n            if return_past_key_values:\n                outputs, past_key_values = outputs\n            outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n            response = tokenizer.decode(outputs)\n            if response and response[-1] != \"�\":\n                response, new_history = self.process_response(response, history)\n                if return_past_key_values:\n                    yield response, new_history, past_key_values\n                else:\n                    yield response, new_history\n\n    @torch.inference_mode()\n    def stream_generate(\n            self,\n            input_ids,\n            generation_config: Optional[GenerationConfig] = None,\n            logits_processor: Optional[LogitsProcessorList] = None,\n            
stopping_criteria: Optional[StoppingCriteriaList] = None,\n            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n            return_past_key_values=False,\n            **kwargs,\n    ):\n        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]\n\n        if generation_config is None:\n            generation_config = self.generation_config\n        generation_config = copy.deepcopy(generation_config)\n        model_kwargs = generation_config.update(**kwargs)\n        model_kwargs[\"use_cache\"] = generation_config.use_cache\n        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id\n\n        if isinstance(eos_token_id, int):\n            eos_token_id = [eos_token_id]\n        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None\n\n        has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n        if has_default_max_length and generation_config.max_new_tokens is None:\n            warnings.warn(\n                f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. \"\n                \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n                \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n                UserWarning,\n            )\n        elif generation_config.max_new_tokens is not None:\n            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n            if not has_default_max_length:\n                logger.warn(\n                    f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n                    f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n                    \"Please refer to the documentation for more information. \"\n                    \"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\",\n                    UserWarning,\n                )\n\n        if input_ids_seq_length >= generation_config.max_length:\n            input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n            logger.warning(\n                f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n                f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n                \" increasing `max_new_tokens`.\"\n            )\n\n        # 2. 
Set generation parameters if not already defined\n        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n\n        logits_processor = self._get_logits_processor(\n            generation_config=generation_config,\n            input_ids_seq_length=input_ids_seq_length,\n            encoder_input_ids=input_ids,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            logits_processor=logits_processor,\n        )\n\n        stopping_criteria = self._get_stopping_criteria(\n            generation_config=generation_config, stopping_criteria=stopping_criteria\n        )\n        logits_warper = self._get_logits_warper(generation_config)\n\n        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)\n        scores = None\n        while True:\n            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n            # forward pass to get next token\n            outputs = self(\n                **model_inputs,\n                return_dict=True,\n                output_attentions=False,\n                output_hidden_states=False,\n            )\n\n            next_token_logits = outputs.logits[:, -1, :]\n\n            # pre-process distribution\n            next_token_scores = logits_processor(input_ids, next_token_logits)\n            next_token_scores = logits_warper(input_ids, next_token_scores)\n\n            # sample\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            if generation_config.do_sample:\n                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n            else:\n                next_tokens = torch.argmax(probs, dim=-1)\n            # update generated ids, model inputs, and length for next step\n            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n            model_kwargs = self._update_model_kwargs_for_generation(\n                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n            )\n            unfinished_sequences = unfinished_sequences.mul(\n                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)\n            )\n            if return_past_key_values:\n                yield input_ids, outputs.past_key_values\n            else:\n                yield input_ids\n            # stop when each sentence is finished, or if we exceed the maximum length\n            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):\n                break\n\n    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):\n        if bits == 0:\n            return\n\n        from .quantization import quantize\n\n        if self.quantized:\n            logger.info(\"Already quantized.\")\n            return self\n\n        self.quantized = True\n\n        self.config.quantization_bit = bits\n\n        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,\n                                            **kwargs)\n        return self\n\n\nclass ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):\n    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n        super().__init__(config)\n\n        self.num_labels = config.num_labels\n        self.transformer = ChatGLMModel(config, empty_init=empty_init, 
device=device)\n\n        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)\n        if config.classifier_dropout is not None:\n            self.dropout = nn.Dropout(config.classifier_dropout)\n        else:\n            self.dropout = None\n        self.config = config\n\n        if self.config.quantization_bit:\n            self.quantize(self.config.quantization_bit, empty_init=True)\n\n    def forward(\n            self,\n            input_ids: Optional[torch.LongTensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            full_attention_mask: Optional[torch.Tensor] = None,\n            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n            inputs_embeds: Optional[torch.LongTensor] = None,\n            labels: Optional[torch.LongTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            attention_mask=attention_mask,\n            full_attention_mask=full_attention_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = transformer_outputs[0]\n        pooled_hidden_states = hidden_states[-1]\n        if self.dropout is not None:\n            pooled_hidden_states = self.dropout(pooled_hidden_states)\n        logits = self.classifier_head(pooled_hidden_states)\n\n        loss = None\n        if labels is not None:\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())\n                else:\n                    loss = loss_fct(logits.float(), labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))\n\n        if not return_dict:\n            output = (logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=transformer_outputs.past_key_values,\n         
   hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "src/resources/models/chatglm3-6B/quantization.py",
    "content": "from torch.nn import Linear\nfrom torch.nn.parameter import Parameter\n\nimport bz2\nimport torch\nimport base64\nimport ctypes\nfrom transformers.utils import logging\n\nfrom typing import List\nfrom functools import partial\n\nlogger = logging.get_logger(__name__)\n\ntry:\n    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up\n\n    class Kernel:\n        def __init__(self, code: bytes, function_names: List[str]):\n            self.code = code\n            self._function_names = function_names\n            self._cmodule = LazyKernelCModule(self.code)\n\n            for name in self._function_names:\n                setattr(self, name, KernelFunction(self._cmodule, name))\n\n    quantization_code = \"$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZ
WYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAe
PUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ\"\n\n    kernels = Kernel(\n        bz2.decompress(base64.b64decode(quantization_code)),\n        [\n            \"int4WeightCompression\",\n            \"int4WeightExtractionFloat\",\n            \"int4WeightExtractionHalf\",\n            \"int8WeightExtractionFloat\",\n            \"int8WeightExtractionHalf\",\n        ],\n    )\nexcept Exception as exception:\n    kernels = None\n    logger.warning(\"Failed to load cpm_kernels:\" + str(exception))\n\n\nclass W8A16Linear(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):\n        ctx.inp_shape = inp.size()\n        ctx.weight_bit_width = weight_bit_width\n        out_features = quant_w.size(0)\n        inp = inp.contiguous().view(-1, inp.size(-1))\n        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)\n        ctx.weight_shape = weight.size()\n        output = inp.mm(weight.t())\n        ctx.save_for_backward(inp, quant_w, scale_w)\n        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor):\n        inp, quant_w, scale_w = ctx.saved_tensors\n        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)\n        grad_output = grad_output.contiguous().view(-1, weight.size(0))\n        grad_input = grad_output.mm(weight)\n        grad_weight = grad_output.t().mm(inp)\n        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None\n\n\ndef compress_int4_weight(weight: torch.Tensor):  # (n, m)\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        assert m % 2 == 0\n        m = m // 2\n        out = torch.empty(n, m, dtype=torch.int8, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        kernels.int4WeightCompression(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],\n        )\n        return out\n\n\ndef extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):\n    assert scale_list.dtype in [torch.half, torch.bfloat16]\n    assert weight.dtype in [torch.int8]\n    if source_bit_width == 8:\n        return weight.to(scale_list.dtype) * scale_list[:, None]\n    elif 
source_bit_width == 4:\n        func = (\n            kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16\n        )\n    else:\n        assert False, \"Unsupported bit-width\"\n\n    with torch.cuda.device(weight.device):\n        n, m = weight.size(0), weight.size(1)\n        out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device=\"cuda\")\n        stream = torch.cuda.current_stream()\n\n        gridDim = (n, 1, 1)\n        blockDim = (min(round_up(m, 32), 1024), 1, 1)\n\n        func(\n            gridDim,\n            blockDim,\n            0,\n            stream,\n            [\n                ctypes.c_void_p(weight.data_ptr()),\n                ctypes.c_void_p(scale_list.data_ptr()),\n                ctypes.c_void_p(out.data_ptr()),\n                ctypes.c_int32(n),\n                ctypes.c_int32(m),\n            ],\n        )\n        return out\n\n\nclass QuantizedLinear(torch.nn.Module):\n    def __init__(self, weight_bit_width: int, weight, bias=None, device=\"cpu\", dtype=None, empty_init=False, *args,\n                 **kwargs):\n        super().__init__()\n        self.weight_bit_width = weight_bit_width\n\n        shape = weight.shape\n\n        if weight is None or empty_init:\n            self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)\n            self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)\n        else:\n            self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)\n            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)\n            if weight_bit_width == 4:\n                self.weight = compress_int4_weight(self.weight)\n\n        self.weight = Parameter(self.weight.to(device), requires_grad=False)\n        self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)\n        self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None\n\n    def forward(self, input):\n        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)\n        if self.bias is not None:\n            output = output + self.bias\n        return output\n\n\ndef quantize(model, weight_bit_width, empty_init=False, device=None):\n    \"\"\"Replace fp16 linear with quantized linear\"\"\"\n    for layer in model.layers:\n        layer.self_attention.query_key_value = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),\n            bias=layer.self_attention.query_key_value.bias,\n            dtype=layer.self_attention.query_key_value.weight.dtype,\n            device=layer.self_attention.query_key_value.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.self_attention.dense = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),\n            bias=layer.self_attention.dense.bias,\n            dtype=layer.self_attention.dense.weight.dtype,\n            device=layer.self_attention.dense.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.mlp.dense_h_to_4h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),\n            bias=layer.mlp.dense_h_to_4h.bias,\n            dtype=layer.mlp.dense_h_to_4h.weight.dtype,\n            device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n        layer.mlp.dense_4h_to_h = QuantizedLinear(\n            weight_bit_width=weight_bit_width,\n            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),\n            bias=layer.mlp.dense_4h_to_h.bias,\n            dtype=layer.mlp.dense_4h_to_h.weight.dtype,\n            device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,\n            empty_init=empty_init\n        )\n\n    return model\n"
  },
  {
    "path": "src/resources/models/chatglm3-6B/tokenization_chatglm.py",
    "content": "import json\nimport os\nimport torch\nfrom typing import List, Optional, Union, Dict\nfrom sentencepiece import SentencePieceProcessor\nfrom transformers import PreTrainedTokenizer\nfrom transformers.utils import logging, PaddingStrategy\nfrom transformers.tokenization_utils_base import EncodedInput, BatchEncoding\n\n\nclass SPTokenizer:\n    def __init__(self, model_path: str):\n        # reload tokenizer\n        assert os.path.isfile(model_path), model_path\n        self.sp_model = SentencePieceProcessor(model_file=model_path)\n\n        # BOS / EOS token IDs\n        self.n_words: int = self.sp_model.vocab_size()\n        self.bos_id: int = self.sp_model.bos_id()\n        self.eos_id: int = self.sp_model.eos_id()\n        # self.pad_id: int = self.sp_model.unk_id()\n        self.pad_id: int = -100\n        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()\n\n        special_tokens = [\"[MASK]\", \"[gMASK]\", \"[sMASK]\", \"sop\", \"eop\", \"<|system|>\", \"<|user|>\", \"<|assistant|>\",\n                          \"<|observation|>\"]\n        self.special_tokens = {}\n        self.index_special_tokens = {}\n        for token in special_tokens:\n            self.special_tokens[token] = self.n_words\n            self.index_special_tokens[self.n_words] = token\n            self.n_words += 1\n\n    def tokenize(self, s: str):\n        return self.sp_model.EncodeAsPieces(s)\n\n    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:\n        assert type(s) is str\n        t = self.sp_model.encode(s)\n        if bos:\n            t = [self.bos_id] + t\n        if eos:\n            t = t + [self.eos_id]\n        return t\n\n    def decode(self, t: List[int]) -> str:\n        text, buffer = \"\", []\n        for token in t:\n            if token in self.index_special_tokens:\n                if buffer:\n                    text += self.sp_model.decode(buffer)\n                    buffer = []\n                text += self.index_special_tokens[token]\n            else:\n                buffer.append(token)\n        if buffer:\n            text += self.sp_model.decode(buffer)\n        return text\n\n    def decode_tokens(self, tokens: List[str]) -> str:\n        text = self.sp_model.DecodePieces(tokens)\n        return text\n\n    def convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) in an id using the vocab. 
\"\"\"\n        if token in self.special_tokens:\n            return self.special_tokens[token]\n        return self.sp_model.PieceToId(token)\n\n    def convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        if index in self.index_special_tokens:\n            return self.index_special_tokens[index]\n        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:\n            return \"\"\n        return self.sp_model.IdToPiece(index)\n\n\nclass ChatGLMTokenizer(PreTrainedTokenizer):\n    vocab_files_names = {\"vocab_file\": \"tokenizer.model\"}\n\n    model_input_names = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n\n    def __init__(self, vocab_file, padding_side=\"left\", clean_up_tokenization_spaces=False, **kwargs):\n        self.name = \"GLMTokenizer\"\n\n        self.vocab_file = vocab_file\n        self.tokenizer = SPTokenizer(vocab_file)\n        self.special_tokens = {\n            \"<bos>\": self.tokenizer.bos_id,\n            \"<eos>\": self.tokenizer.eos_id,\n            \"<pad>\": self.tokenizer.pad_id\n        }\n        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)\n\n    def get_command(self, token):\n        if token in self.special_tokens:\n            return self.special_tokens[token]\n        assert token in self.tokenizer.special_tokens, f\"{token} is not a special token for {self.name}\"\n        return self.tokenizer.special_tokens[token]\n\n    @property\n    def unk_token(self) -> str:\n        return \"<unk>\"\n\n    @property\n    def pad_token(self) -> str:\n        return \"<unk>\"\n\n    @property\n    def pad_token_id(self):\n        return self.get_command(\"<pad>\")\n\n    @property\n    def eos_token(self) -> str:\n        return \"</s>\"\n\n    @property\n    def eos_token_id(self):\n        return self.get_command(\"<eos>\")\n\n    @property\n    def vocab_size(self):\n        return self.tokenizer.n_words\n\n    def get_vocab(self):\n        \"\"\" Returns vocab as a dict \"\"\"\n        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text, **kwargs):\n        return self.tokenizer.tokenize(text)\n\n    def _convert_token_to_id(self, token):\n        \"\"\" Converts a token (str) in an id using the vocab. 
\"\"\"\n        return self.tokenizer.convert_token_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.tokenizer.convert_id_to_token(index)\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        return self.tokenizer.decode_tokens(tokens)\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the named of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"vocab_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)\n\n    def get_prefix_tokens(self):\n        prefix_tokens = [self.get_command(\"[gMASK]\"), self.get_command(\"sop\")]\n        return prefix_tokens\n\n    def build_single_message(self, role, metadata, message):\n        assert role in [\"system\", \"user\", \"assistant\", \"observation\"], role\n        role_tokens = [self.get_command(f\"<|{role}|>\")] + self.tokenizer.encode(f\"{metadata}\\n\")\n        message_tokens = self.tokenizer.encode(message)\n        tokens = role_tokens + message_tokens\n        return tokens\n\n    def build_chat_input(self, query, history=None, role=\"user\"):\n        if history is None:\n            history = []\n        input_ids = []\n        for item in history:\n            content = item[\"content\"]\n            if item[\"role\"] == \"system\" and \"tools\" in item:\n                content = content + \"\\n\" + json.dumps(item[\"tools\"], indent=4, ensure_ascii=False)\n            input_ids.extend(self.build_single_message(item[\"role\"], item.get(\"metadata\", \"\"), content))\n        input_ids.extend(self.build_single_message(role, \"\", query))\n        input_ids.extend([self.get_command(\"<|assistant|>\")])\n        return self.batch_encode_plus([input_ids], return_tensors=\"pt\", is_split_into_words=True)\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A BERT sequence has the following format:\n\n        - single sequence: `[CLS] X [SEP]`\n        - pair of sequences: `[CLS] A [SEP] B [SEP]`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        prefix_tokens = self.get_prefix_tokens()\n        token_ids_0 = prefix_tokens + token_ids_0\n        if token_ids_1 is not None:\n            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command(\"<eos>\")]\n        return token_ids_0\n\n    def _pad(\n            self,\n            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],\n            max_length: Optional[int] = None,\n            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n            pad_to_multiple_of: Optional[int] = None,\n            return_attention_mask: Optional[bool] = None,\n    ) -> dict:\n        \"\"\"\n        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)\n\n        Args:\n            encoded_inputs:\n                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).\n            max_length: maximum length of the returned list and optionally padding length (see below).\n                Will truncate by taking into account the special tokens.\n            padding_strategy: PaddingStrategy to use for padding.\n\n                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch\n                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)\n                - PaddingStrategy.DO_NOT_PAD: Do not pad\n                The tokenizer padding sides are defined in self.padding_side:\n\n                    - 'left': pads on the left of the sequences\n                    - 'right': pads on the right of the sequences\n            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.\n                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability\n                `>= 7.5` (Volta).\n            return_attention_mask:\n                (optional) Set to False to avoid returning attention mask (default: set to model specifics)\n        \"\"\"\n        # Load from model defaults\n        assert self.padding_side == \"left\"\n\n        required_input = encoded_inputs[self.model_input_names[0]]\n        seq_length = len(required_input)\n\n        if padding_strategy == PaddingStrategy.LONGEST:\n            max_length = len(required_input)\n\n        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n\n        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length\n\n        # Initialize attention mask if not present.\n        if \"attention_mask\" not in encoded_inputs:\n            encoded_inputs[\"attention_mask\"] = [1] * seq_length\n\n        if \"position_ids\" not in encoded_inputs:\n            encoded_inputs[\"position_ids\"] = list(range(seq_length))\n\n        if needs_to_be_padded:\n            difference = max_length - len(required_input)\n\n   
         if \"attention_mask\" in encoded_inputs:\n                encoded_inputs[\"attention_mask\"] = [0] * difference + encoded_inputs[\"attention_mask\"]\n            if \"position_ids\" in encoded_inputs:\n                encoded_inputs[\"position_ids\"] = [0] * difference + encoded_inputs[\"position_ids\"]\n            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input\n\n        return encoded_inputs\n"
  },
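  {
    "path": "examples/chatglm3_chat_input_sketch.py",
    "content": "# Hypothetical usage sketch (added for illustration; this file is not part of the original\n# repo, and the local model path below is an assumption). It shows how build_chat_input in\n# tokenization_chatglm.py assembles a ChatGLM3 prompt: per-turn <|role|> command tokens plus\n# the encoded message, a trailing <|assistant|> command, and the [gMASK] sop prefix added by\n# build_inputs_with_special_tokens during encoding.\nfrom transformers import AutoTokenizer\n\ntokenizer = AutoTokenizer.from_pretrained(\"src/resources/models/chatglm3-6B\", trust_remote_code=True)\n\nhistory = [\n    {\"role\": \"user\", \"content\": \"你好\"},\n    {\"role\": \"assistant\", \"metadata\": \"\", \"content\": \"你好！\"},\n]\ninputs = tokenizer.build_chat_input(\"今天天气怎么样？\", history=history, role=\"user\")\nprint(inputs[\"input_ids\"])  # begins with the [gMASK] and sop command ids\n"
  },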
  {
    "path": "src/resources/models/chatglm3-6B/tokenizer_config.json",
    "content": "{\n  \"name_or_path\": \"THUDM/chatglm3-6b\",\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"ChatGLMTokenizer\",\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_chatglm.ChatGLMTokenizer\",\n      null\n      ]\n  }\n}\n"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/README.md",
    "content": "---\nlanguage:\n- zh\ntags:\n- glm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n本仓库由[THUDM/glm-10b-chinese](https://huggingface.co/THUDM/glm-10b-chinese) fork而来，原仓库实现了PyTorch版本的GLM模型，该模型有100亿参数量，模型权重文件以FP32格式存储。\n\n本仓库在原始代码的基础上进行了部分调整，以支持ChatGPT训练pipeline，具体实现可参考：[sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\nThis repository is forked from [THUDM/glm-10b-chinese](https://huggingface.co/THUDM/glm-10b-chinese) that contains PyTorch implementation of GLM model with 10 billion parameters pretrained weights (FP32 precision).\n\nIt is slightly different from the original GLM implementation to support the ChatGPT training pipeline in this github repo: [sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\n---\n\n# Model description\nGLM is a General Language Model pretrained with an autoregressive blank-filling objective and can be finetuned on various natural language understanding and generation tasks.\n\nPlease refer to our paper for a detailed description of GLM:\n\n[GLM: General Language Model Pretraining with Autoregressive Blank Infilling](https://arxiv.org/abs/2103.10360) (ACL 2022)\n\nZhengxiao Du*, Yujie Qian*, Xiao Liu, Ming Ding, Jiezhong Qiu, Zhilin Yang, Jie Tang (*: equal contribution)\n\nFind more examples in our [Github repo](https://github.com/THUDM/GLM).\n\n`glm-10b-chinese` is pretrained on the [WuDaoCorpora](https://www.sciencedirect.com/science/article/pii/S2666651021000152) dataset. It has 48 transformer layers, with hidden size 4096 and 64 attention heads in each layer. The model is pretrained with autoregressive blank filling objectives designed for natural language understanding, seq2seq, and language modeling.\n\n---\n\n# Usage (Text Generation)\n```python\nfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM\ntokenizer = AutoTokenizer.from_pretrained(\"sunzeyeah/glm-10B-chinese\", trust_remote_code=True)\nmodel = AutoModelForSeq2SeqLM.from_pretrained(\"sunzeyeah/glm-10B-chinese\", trust_remote_code=True)\nmodel = model.half().cuda()\n\nmax_length = 512\nprompt = \"我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点\"\nprefix = \"回答：\"\nencoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\nprompt_length = len(encoded_prompt['input_ids'])\nencoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n                         max_length=min(prompt_length, max_length),\n                         truncation=\"only_first\",\n                         return_tensors=\"pt\",\n                         return_token_type_ids=False)\nmax_gen_length = max_length - encoded_dict['input_ids'].shape[1]\ninputs = tokenizer.build_inputs_for_generation(encoded_dict, max_gen_length=max_gen_length, padding=True)\ninputs = inputs.cuda()\noutputs = model.generate(**inputs,\n                         max_new_tokens=max_gen_length,\n                         eos_token_id=tokenizer.eop_token_id,\n                         pad_token_id=tokenizer.pad_token_id,\n                         do_sample=False,\n                         num_return_sequences=1,\n                         top_p=0.8,\n                         temperature=1.0)\nresults = tokenizer.batch_decode(outputs, skip_special_tokens=True)\nprint(results)\n```\n\n"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/added_tokens.json",
    "content": "{\n  \"<|endoftext|>\": 50000,\n  \"[SEP]\": 50001,\n  \"[CLS]\": 50002,\n  \"[MASK]\": 50003,\n  \"[UNUSED1]\": 50004,\n  \"[UNUSED2]\": 50005,\n  \"<|startofpiece|>\": 50006,\n  \"<|endofpiece|>\": 50007,\n  \"[sMASK]\": 50008,\n  \"[gMASK]\": 50009\n}"
  },
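  {
    "path": "examples/glm_added_tokens_sketch.py",
    "content": "# Hypothetical cross-check sketch (added for illustration; this file is not part of the\n# original repo, and the local path is an assumption). added_tokens.json appends the GLM\n# control tokens after the 50000-entry base vocab; note that [gMASK] (id 50009) is the same\n# id that config.json lists under bad_words_id, and vocab_size 50048 leaves headroom above\n# the last added token.\nimport json\n\nwith open(\"src/resources/models/glm-10B-chinese/added_tokens.json\", encoding=\"utf-8\") as f:\n    added = json.load(f)\n\nassert added[\"<|endoftext|>\"] == 50000 and added[\"[gMASK]\"] == 50009\nprint(sorted(added.items(), key=lambda kv: kv[1]))\n"
  },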
  {
    "path": "src/resources/models/glm-10B-chinese/config.json",
    "content": "{\n  \"model_name_or_path\": \"sunzeyeah/glm-10B-chinese\",\n  \"architectures\": [\n    \"GLMModel\"\n  ],\n  \"attention_dropout_prob\": 0.1,\n  \"attention_scale\": 1.0,\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_glm.GLMConfig\",\n    \"AutoModel\": \"modeling_glm.GLMModel\",\n    \"AutoModelForSequenceClassification\": \"modeling_glm.GLMForSequenceClassification\",\n    \"AutoModelForSeq2SeqLM\": \"modeling_glm.GLMForConditionalGeneration\"\n    },\n  \"block_position_encoding\": true,\n  \"checkpoint_activations\": false,\n  \"checkpoint_num_layers\": 1,\n  \"embedding_dropout_prob\": 0.1,\n  \"hidden_size\": 4096,\n  \"initializer_range\": 0.02,\n  \"max_sequence_length\": 1024,\n  \"model_type\": \"glm\",\n  \"num_attention_heads\": 64,\n  \"num_layers\": 48,\n  \"output_dropout_prob\": 0.1,\n  \"output_predict\": true,\n  \"parallel_output\": true,\n  \"pool_token\": \"cls\",\n  \"relative_encoding\": false,\n  \"spell_func\": \"lstm\",\n  \"spell_length\": null,\n  \"torch_dtype\": \"float32\",\n  \"vocab_size\": 50048,\n  \"bad_words_id\": [50009],\n  \"pad_token_id\": 50000\n}"
  },
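  {
    "path": "examples/glm_config_check_sketch.py",
    "content": "# Hypothetical sanity-check sketch (added for illustration; this file is not part of the\n# original repo, and the local path is an assumption). The auto_map block in config.json is\n# what lets AutoConfig / AutoModelForSeq2SeqLM resolve the custom GLM classes shipped next to\n# the weights when trust_remote_code=True.\nfrom transformers import AutoConfig\n\nconfig = AutoConfig.from_pretrained(\"src/resources/models/glm-10B-chinese\", trust_remote_code=True)\nassert config.model_type == \"glm\"\nassert config.hidden_size == 4096 and config.num_layers == 48 and config.num_attention_heads == 64\nprint(type(config).__name__)  # GLMConfig, loaded via configuration_glm.py\n"
  },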
  {
    "path": "src/resources/models/glm-10B-chinese/configuration_glm.py",
    "content": "# coding=utf-8\n# Copyright 2022 shunxing1234 and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" GLM model configuration \"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nGLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"shunxing1234/GLM\": \"https://huggingface.co/shunxing1234/GLM/resolve/main/config.json\",\n    # See all GLM models at https://huggingface.co/models?filter=glm\n}\n\n\nclass GLMConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`~GLMModel`].\n    It is used to instantiate an GLM model according to the specified arguments, defining the model\n    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of\n    the GLM [shunxing1234/GLM-base-cased](https://huggingface.co/shunxing1234/GLM-base-cased) architecture.\n\n    Configuration objects inherit from  [`PretrainedConfig`] and can be used\n    to control the model outputs. Read the documentation from  [`PretrainedConfig`]\n    for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 30522):\n            Vocabulary size of the GLM model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`~GLMModel`] or\n            [`~TFGLMModel`].\n        hidden_size (`int`, *optional*, defaults to 768):\n            Dimension of the encoder layers and the pooler layer.\n        num_hidden_layers (`int`, *optional*, defaults to 12):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 12):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        intermediate_size (`int`, *optional*, defaults to 3072):\n            Dimension of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu\"`):\n            The non-linear activation function (function or string) in the encoder and pooler.\n            If string, `\"gelu\"`, `\"relu\"`, `\"selu\"` and `\"gelu_new\"` are supported.\n        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):\n            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.\n        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):\n            The dropout ratio for the attention probabilities.\n        max_position_embeddings (`int`, *optional*, defaults to 512):\n            The maximum sequence length that this model might ever be used with.\n            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).\n        type_vocab_size (`int`, *optional*, defaults to 2):\n            The vocabulary size of the `token_type_ids` passed when calling [`~GLMModel`] or\n            [`~TFGLMModel`].\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        layer_norm_eps (`float`, *optional*, defaults to 1e-12):\n            The epsilon used by the layer normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). 
Only\n            relevant if `config.is_decoder=True`.\n        Example:\n\n    ```python\n    >>> from transformers import GLMModel, GLMConfig\n\n    >>> # Initializing a GLM shunxing1234/GLM-base-cased style configuration\n    >>> configuration = GLMConfig()\n\n    >>> # Initializing a model from the shunxing1234/GLM-base-cased style configuration\n    >>> model = GLMModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n\"\"\"\n    model_type = \"glm\"\n    attribute_map = {\n        \"num_hidden_layers\": \"num_layers\"\n    }\n\n    def __init__(\n            self,\n            num_layers=24,\n            vocab_size=30592,\n            hidden_size=1024,\n            num_attention_heads=16,\n            embedding_dropout_prob=0.1,\n            attention_dropout_prob=0.1,\n            output_dropout_prob=0.1,\n            max_sequence_length=512,\n            checkpoint_activations=False,\n            checkpoint_num_layers=1,\n            parallel_output=True,\n            relative_encoding=False,\n            block_position_encoding=True,\n            output_predict=False,\n            spell_length=None,\n            spell_func=\"lstm\",\n            attention_scale=1.0,\n            initializer_range=0.02,\n            pool_token=\"cls\",\n            **kwargs\n    ):\n        self.num_layers = num_layers\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_attention_heads = num_attention_heads\n        self.embedding_dropout_prob = embedding_dropout_prob\n        self.attention_dropout_prob = attention_dropout_prob\n        self.output_dropout_prob = output_dropout_prob\n        self.max_sequence_length = max_sequence_length\n        self.checkpoint_activations = checkpoint_activations\n        self.checkpoint_num_layers = checkpoint_num_layers\n        self.parallel_output = parallel_output\n        self.relative_encoding = relative_encoding\n        self.block_position_encoding = block_position_encoding\n        self.output_predict = output_predict\n        self.spell_length = spell_length\n        self.spell_func = spell_func\n        self.attention_scale = attention_scale\n        self.initializer_range = initializer_range\n        self.pool_token = pool_token\n\n        super().__init__(**kwargs)"
  },
  {
    "path": "src/resources/models/glm-10B-chinese/modeling_glm.py",
    "content": "# coding=utf-8\n# Copyright 2022 shunxing1234 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch GLM model. \"\"\"\n\nimport math\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\nfrom torch.nn import init, LayerNorm, Linear, CrossEntropyLoss\n\nfrom transformers.activations import gelu\nfrom transformers.utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    ModelOutput,\n    SequenceClassifierOutput,\n)\n\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n)\nfrom .configuration_glm import GLMConfig\nfrom torch.nn.parameter import Parameter\n\n_CHECKPOINT_FOR_DOC = \"shunxing1234/GLM\"\n_CONFIG_FOR_DOC = \"GLMConfig\"\n_TOKENIZER_FOR_DOC = \"GLMTokenizer\"\n\nGLM_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"shunxing1234/GLM\",\n    # See all GLM models at https://huggingface.co/models?filter=glm\n]\n\n\ndef unscaled_init_method(sigma):\n    \"\"\"Init method based on N(0, sigma).\"\"\"\n\n    def init_(tensor):\n        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)\n\n    return init_\n\n\ndef scaled_init_method(mean, std, num_layers):\n    \"\"\"Init method based on N(0, sigma/sqrt(2*num_layers).\"\"\"\n    std = std / math.sqrt(2.0 * num_layers)\n\n    def init_(tensor):\n        return torch.nn.init.normal_(tensor, mean=mean, std=std)\n\n    return init_\n\n\ndef ensure_divisibility(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, '{} is not divisible by {}'.format(\n        numerator, denominator)\n\n\ndef divide(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(tensor, num_partitions,\n                                contiguous_split_chunks=False):\n    \"\"\"Split a tensor along its last dimension.\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # Note: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\nclass MLP(torch.nn.Module):\n    \"\"\"MLP for GPT2.\n\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, perform 
gelu transformation, and project the\n    state back into h hidden dimension. At the end, dropout is also\n    applied.\n\n    Arguments:\n        hidden_size: The hidden size of the self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        init_method: initialization method used for the weights. Note\n                     that all biases are initialized to zero and\n                     layernorm weights are initialized to one.\n        output_layer_init_method: output layer initialization. If None,\n                                  use `init_method`.\n    \"\"\"\n\n    def __init__(self, hidden_size, output_dropout_prob, init_method,\n                 output_layer_init_method=None):\n        super(MLP, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n        # Project to 4h.\n        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)\n\n        # Project back to h.\n        self.dense_4h_to_h = Linear(\n            4 * hidden_size,\n            hidden_size)\n\n        self.dropout = torch.nn.Dropout(output_dropout_prob)\n\n    def forward(self, hidden_states):\n        # [b, s, 4hp]\n        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n        intermediate_parallel = gelu(intermediate_parallel)\n\n        # [b, s, h]\n        output = self.dense_4h_to_h(intermediate_parallel)\n        output = self.dropout(output)\n        return output\n\n\nclass VocabEmbedding(torch.nn.Module):\n    \"\"\"Embedding parallelized in the vocabulary dimension.\n\n    This is mainly adapted from torch.nn.Embedding and all the default\n    values are kept.\n    Arguments:\n        num_embeddings: vocabulary size.\n        embedding_dim: size of hidden state.\n        init_method: method to initialize weights.\n    \"\"\"\n\n    def __init__(self, config):\n        super(VocabEmbedding, self).__init__()\n        # Keep the input dimensions.\n        self.num_embeddings = config.vocab_size\n        self.embedding_dim = config.hidden_size\n        # Set the defaults for compatibility.\n        self.padding_idx = None\n        self.max_norm = None\n        self.norm_type = 2.\n        self.scale_grad_by_freq = False\n        self.sparse = False\n        self._weight = None\n\n        self.vocab_start_index = 0\n        self.vocab_end_index = self.num_embeddings\n\n        # Allocate weights.\n        weight = torch.Tensor(self.num_embeddings, self.embedding_dim)\n        self.register_buffer(\"weight\", weight)\n\n        # And initialize.\n        init.xavier_normal_(self.weight)\n\n    def forward(self, input_):\n        # Get the embeddings.\n        output = F.embedding(input_, self.weight,\n                             self.padding_idx, self.max_norm,\n                             self.norm_type, self.scale_grad_by_freq,\n                             self.sparse)\n        return output\n\n\nclass PositionalEmbedding(torch.nn.Module):\n\n    def __init__(self, hidden_size):\n        super(PositionalEmbedding, self).__init__()\n\n        self.hidden_size = hidden_size\n\n        inv_freq = 1 / (10000 ** (torch.arange(0.0, hidden_size, 2.0) / hidden_size))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def forward(self, pos_seq, bsz=None):\n        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)\n        pos_emb = torch.cat([sinusoid_inp.sin(), 
sinusoid_inp.cos()], dim=-1)\n\n        if bsz is not None:\n            return pos_emb[None, :, :].expand(bsz, -1, -1)\n        else:\n            return pos_emb[None, :, :]\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"Self-attention layer for GLM.\n\n    Self-attention layer takes input with size [b, s, h] where b is\n    the batch size, s is the sequence length, and h is the hidden size\n    and creates output of the same size.\n    Arguments:\n        hidden_size: total hidden size of the layer (h).\n        num_attention_heads: number of attention heads (n). Note that we\n                             require n to be divisible by number of GPUs\n                             used to parallelize the model. Also, we\n                             require hidden size to be divisible by n.\n        attention_dropout_prob: dropout probability for the attention scores.\n        init_method: weight initialization.\n        output_layer_init_method: output layer initialization. If None, use\n                                  `init_method`.\n    We use the following notation:\n        h: hidden_size\n        n: num_attention_heads\n        p: number of partitions\n        np: n/p\n        hp: h/p\n        hn: h/n\n        b: batch size\n        s: sequence length\n    \"\"\"\n\n    def __init__(self, hidden_size, num_attention_heads,\n                 attention_dropout_prob, output_dropout_prob,\n                 init_method, output_layer_init_method=None,\n                 attention_scale=1.0):\n        super(SelfAttention, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n        # Per attention head and per partition values.\n        self.hidden_size = hidden_size\n        self.hidden_size_per_attention_head = divide(hidden_size,\n                                                     num_attention_heads)\n\n        self.num_attention_heads = num_attention_heads\n        self.attention_scale = attention_scale\n        # Strided linear layer.\n        self.query_key_value = Linear(hidden_size, 3 * hidden_size)\n\n        # Dropout. Note that for a single iteration, this layer will generate\n        # different outputs on different number of parallel partitions but\n        # on average it should not be partition dependent.\n        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)\n\n        # Output.\n        self.dense = Linear(hidden_size,\n                            hidden_size)\n        self.output_dropout = torch.nn.Dropout(output_dropout_prob)\n\n    def _transpose_for_scores(self, tensor):\n        \"\"\"Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with\n        size [b, np, s, hn].\n        \"\"\"\n        new_tensor_shape = tensor.size()[:-1] + \\\n                           (self.num_attention_heads,\n                            self.hidden_size_per_attention_head)\n        tensor = tensor.view(*new_tensor_shape)\n        return tensor.permute(0, 2, 1, 3)\n\n    def forward(self, hidden_states, ltor_mask, mem=None):\n        # hidden_states: [b, s, h]\n        # ltor_mask: [b,1,s,s]\n\n        # Attention heads. 
[b, s, hp]\n        query_length = hidden_states.size(1)\n        # self attention\n        if mem is None:\n            mixed_x_layer = self.query_key_value(hidden_states)\n            (mixed_query_layer,\n             mixed_key_layer,\n             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n        else:\n            cat = torch.cat((mem, hidden_states), 1)\n            mixed_x_layer = self.query_key_value(cat)\n            (mixed_query_layer,\n             mixed_key_layer,\n             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n            mixed_query_layer = mixed_query_layer[:, -query_length:]\n\n        # Reshape and transpose [b, np, s, hn]\n        query_layer = self._transpose_for_scores(mixed_query_layer)\n        key_layer = self._transpose_for_scores(mixed_key_layer)\n        value_layer = self._transpose_for_scores(mixed_value_layer)\n\n        if self.attention_scale > 1.0:\n            # Raw attention scores. [b, np, s, s]\n            attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_scale),\n                                            key_layer.transpose(-1, -2) / math.sqrt(\n                                                self.hidden_size_per_attention_head * self.attention_scale))\n        else:\n            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2) / math.sqrt(\n                self.hidden_size_per_attention_head))\n\n        # Apply the left to right attention mask.\n        ltor_mask = ltor_mask.type_as(attention_scores)\n        attention_scores = torch.mul(attention_scores, ltor_mask)\n        if self.attention_scale > 1.0:\n            max_attention_scores = attention_scores.max(dim=-1, keepdim=True)[0]\n            attention_scores -= max_attention_scores\n            attention_scores *= self.attention_scale\n\n        attention_scores = attention_scores + (-65504.0) * (1.0 - ltor_mask)\n        # Attention probabilities. [b, np, s, s]\n        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        # with get_cuda_rng_tracker().fork():\n        attention_probs = self.attention_dropout(attention_probs)\n\n        # Context layer.\n        # [b, np, s, hn]\n        context_layer = torch.matmul(attention_probs, value_layer)\n        # [b, s, np, hn]\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + \\\n                                  (self.hidden_size,)\n        # [b, s, hp]\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        # Output. 
[b, s, h]\n        output = self.dense(context_layer)\n        output = self.output_dropout(output)\n\n        return output\n\n\nclass GLMBlock(torch.nn.Module):\n    \"\"\"A single layer transformer for GLM.\n\n    We use the following notation:\n        h: hidden size\n        n: number of attention heads\n        b: batch size\n        s: sequence length\n    Transformer layer takes input with size [b, s, h] and returns an\n    output of the same size.\n\n    Arguments:\n        hidden_size: The hidden size of the self attention.\n        num_attention_heads: number of attention heads in the self\n                             attention.\n        attention_dropout_prob: dropout probability of the attention\n                                score in self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        layernorm_epsilon: epsilon used in layernorm to avoid\n                           division by zero.\n        init_method: initialization method used for the weights. Note\n                     that all biases are initialized to zero and\n                     layernorm weights are initialized to one.\n        output_layer_init_method: output layers (attention output and\n                                  mlp output) initialization. If None,\n                                  use `init_method`.\n    \"\"\"\n\n    def __init__(self,\n                 hidden_size,\n                 num_attention_heads,\n                 attention_dropout_prob,\n                 output_dropout_prob,\n                 layernorm_epsilon,\n                 init_method,\n                 output_layer_init_method=None,\n                 attention_scale=1.0):\n        super(GLMBlock, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)\n\n        # Self attention.\n        self.attention = SelfAttention(\n            hidden_size,\n            num_attention_heads,\n            attention_dropout_prob,\n            output_dropout_prob,\n            init_method,\n            output_layer_init_method=output_layer_init_method,\n            attention_scale=attention_scale)\n\n        # Layernorm on the attention output.\n        self.post_attention_layernorm = LayerNorm(hidden_size,\n                                                  eps=layernorm_epsilon)\n\n        # MLP\n        self.mlp = MLP(\n            hidden_size,\n            output_dropout_prob,\n            init_method,\n            output_layer_init_method=output_layer_init_method)\n\n    def forward(self, hidden_states, ltor_mask, mem=None):\n        # hidden_states: [b, s, h]\n        # ltor_mask: [b,1, s,s]\n\n        # Layer norm at the beginning of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        mem = self.input_layernorm(mem) if mem is not None else None\n        # Self attention.\n        attention_output = self.attention(layernorm_output, ltor_mask, mem)\n        # Residual connection.\n        layernorm_input = hidden_states + attention_output\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n        # MLP.\n        mlp_output = self.mlp(layernorm_output)\n        # Second residual connection.\n        
output = layernorm_input + mlp_output\n\n        return output\n\n\nclass GLMStack(torch.nn.Module):\n    \"\"\"GLM transformer.\n\n    This module takes input from embedding layer and its output can\n    be used directly by a logit layer. It consists of L (num-layers)\n    blocks of:\n        layer norm\n        self attention\n        residual connection\n        layer norm\n        mlp\n        residual connection\n    followed by a final layer norm.\n\n    Arguments:\n        num_layers: Number of transformer layers.\n        hidden_size: The hidden size of the self attention.\n        num_attention_heads: number of attention heads in the self\n                             attention.\n        attention_dropout_prob: dropout probability of the attention\n                                score in self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        checkpoint_activations: if True, checkpoint activations.\n        checkpoint_num_layers: number of layers to checkpoint. This\n                               is basically the chunk size in checkpointing.\n        layernorm_epsilon: epsilon used in layernorm to avoid\n                           division by zero.\n        init_method_std: standard deviation of the init method which has\n                         the form N(0, std).\n        use_scaled_init_for_output_weights: If True use 1/sqrt(2*num_layers)\n                                            scaling for the output weights (\n                                            output of self attention and mlp).\n    \"\"\"\n\n    def __init__(self,\n                 num_layers,\n                 hidden_size,\n                 num_attention_heads,\n                 max_sequence_length,\n                 embedding_dropout_prob,\n                 attention_dropout_prob,\n                 output_dropout_prob,\n                 checkpoint_activations,\n                 checkpoint_num_layers=1,\n                 layernorm_epsilon=1.0e-5,\n                 init_method_std=0.02,\n                 use_scaled_init_for_output_weights=True,\n                 block_position_encoding=False,\n                 attention_scale=1.0,\n                 ):\n        super(GLMStack, self).__init__()\n        self.hidden_size = hidden_size\n        # Store activation checkpointing flag.\n        self.checkpoint_activations = checkpoint_activations\n        self.checkpoint_num_layers = checkpoint_num_layers\n\n        output_layer_init_method = None\n        if use_scaled_init_for_output_weights:\n            output_layer_init_method = scaled_init_method(0.0, init_method_std,\n                                                          num_layers)\n        # Embeddings dropout\n        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)\n        self.block_position_encoding = block_position_encoding\n\n        # Position embedding (serial).\n        if block_position_encoding:\n            self.position_embeddings = torch.nn.Embedding(max_sequence_length + 1, hidden_size)\n            self.block_position_embeddings = torch.nn.Embedding(max_sequence_length + 1, hidden_size)\n            torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)\n        else:\n            self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)\n        # Initialize the position embeddings.\n        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, 
std=init_method_std)\n\n        def get_layer():\n\n            return GLMBlock(\n                hidden_size,\n                num_attention_heads,\n                attention_dropout_prob,\n                output_dropout_prob,\n                layernorm_epsilon,\n                unscaled_init_method(init_method_std),\n                output_layer_init_method=output_layer_init_method,\n                attention_scale=attention_scale)\n\n        # Transformer layers.\n        self.layers = torch.nn.ModuleList(\n            [get_layer() for _ in range(num_layers)])\n\n        # Final layer norm before output.\n        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)\n\n    def forward(self, hidden_states, position_ids, attention_mask, memory_states=None):\n\n        batch_size, query_length = hidden_states.size()[:2]\n        memory_length = memory_states[0].size(1) if memory_states else 0\n        # attention mask is the beginning position of B region, \in [0, query_len)\n        is_scalar = torch.numel(attention_mask) == 1\n        is_sep = is_scalar or torch.numel(attention_mask) == batch_size\n        if is_sep:\n            sep = attention_mask.item() if is_scalar else attention_mask\n\n            # conventional transformer\n            def build_mask_matrix(seq_length, sep, memory_length=0):\n                m = hidden_states.new_ones((1, seq_length, seq_length))\n                m = torch.tril(m)\n                if is_scalar:\n                    m[0, :, :int(sep)] = 1\n                else:\n                    m = m.expand(batch_size, -1, -1)\n                    ids = torch.arange(seq_length, device=sep.device, dtype=sep.dtype).view(1, -1)\n                    mask = ids < sep.view(-1, 1)\n                    m = m.masked_fill(mask.unsqueeze(1).expand_as(m), 1)\n                if memory_length > 0:\n                    m = m.expand(batch_size, -1, -1)\n                    m = torch.cat((hidden_states.new_ones((batch_size, seq_length, memory_length)), m), dim=2)\n                m = m.unsqueeze(1)\n                return m\n\n            attention_mask = build_mask_matrix(query_length, sep, memory_length=memory_length)\n        else:\n            if attention_mask.dim() == 2:\n                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)\n            attention_mask = attention_mask[:, :, :, -query_length - memory_length:]\n\n        if self.block_position_encoding:\n            position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]\n        position_embeddings = self.position_embeddings(position_ids)\n\n        hidden_states = hidden_states + position_embeddings\n        if self.block_position_encoding:\n            block_position_embeddings = self.block_position_embeddings(block_position_ids)\n            hidden_states = hidden_states + block_position_embeddings\n        hidden_states = self.embedding_dropout(hidden_states)\n\n        def check_detach(_hidden_states):\n            return _hidden_states.detach()\n\n        mem_layers = [check_detach(hidden_states)]\n\n        for i, layer in enumerate(self.layers):\n\n            args = [hidden_states, attention_mask]\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    # None for past_key_value\n                    return module(*inputs)\n\n                return custom_forward\n\n            mem_i = memory_states[i] if memory_states else None\n\n            if self.checkpoint_activations:\n                
# Pass the mask and memory positionally: checkpoint() only forwards positional\n                # arguments to the wrapped function, and the layer's forward requires ltor_mask\n                # as its second argument.\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer),\n                    hidden_states,\n                    attention_mask,\n                    mem_i,\n                )\n            else:\n                hidden_states = layer(*args, mem=mem_i)\n            mem_layers.append(check_detach(hidden_states))\n\n        # Final layer norm.\n        output = self.final_layernorm(hidden_states)\n        mem_layers = self.update_mems(mem_layers, memory_states)\n        return (output, mem_layers)\n\n    def update_mems(self, hiddens, mems):\n        memory_length = mems[0].size(1) if mems else 0\n        query_length = hiddens[0].size(1)\n        new_memory_length = memory_length + query_length\n\n        new_mems = []\n        # with torch.no_grad():\n        for i in range(len(hiddens)):\n            if new_memory_length <= query_length:\n                new_mems.append(hiddens[i][:, -new_memory_length:])\n            else:\n                new_mems.append(torch.cat((mems[i][:, -new_memory_length + query_length:], hiddens[i]), dim=1))\n        return new_mems\n\n\nclass GLMPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    config_class = GLMConfig\n    base_model_prefix = \"glm\"\n    supports_gradient_checkpointing = True\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\" Initialize the weights \"\"\"\n        if isinstance(module, torch.nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, torch.nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GLMModel):\n            module.gradient_checkpointing = value\n\n\nGLM_START_DOCSTRING = r\"\"\"\n    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general\n    usage and behavior.\n\n    Parameters:\n        config ([`~GLMConfig`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the configuration.\n            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nGLM_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `({0})`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using [`GLMTokenizer`].\n            See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.FloatTensor` 
of shape `({0})`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:\n\n            - 0 corresponds to a *sentence A* token,\n            - 1 corresponds to a *sentence B* token.\n\n            [What are token type IDs?](../glossary#token-type-ids)\n        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings.\n            Selected in the range `[0, config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert *input_ids* indices into associated vectors\n            than the model's internal embedding lookup matrix.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare GLM Model transformer outputting raw hidden-states without any specific head on top.\",\n    GLM_START_DOCSTRING,\n)\nclass GLMModel(GLMPreTrainedModel):\n    \"\"\"\n\n    The model can behave as an encoder (with only self-attention) as well\n    as a decoder, in which case a layer of cross-attention is added between\n    the self-attention layers, following the architecture described in [Attention is\n    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,\n    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.\n\n    To behave as a decoder the model needs to be initialized with the\n    `is_decoder` argument of the configuration set to `True`.\n    To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`\n    argument and `add_cross_attention` set to `True`; an\n    `encoder_hidden_states` is then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.config = config\n        self.output_predict = config.output_predict\n        # Word embeddings (parallel).\n        self.word_embeddings = VocabEmbedding(config)\n\n        # Transformer\n        self.transformer = GLMStack(config.num_layers,\n                                    config.hidden_size,\n                                    config.num_attention_heads,\n                                    config.max_sequence_length,\n                                    config.embedding_dropout_prob,\n                                    config.attention_dropout_prob,\n                                    config.output_dropout_prob,\n                                    config.checkpoint_activations,\n                                    config.checkpoint_num_layers,\n                                    attention_scale=config.attention_scale,\n                                    block_position_encoding=config.block_position_encoding)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        processor_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutputWithPastAndCrossAttentions,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            attention_mask=None,\n            mems=None,\n            **kwargs\n    ):\n        batch_size = input_ids.size(0)\n        words_embeddings = self.word_embeddings(input_ids)\n        embeddings = words_embeddings\n\n        device = input_ids.device\n        input_shape = input_ids.size()\n\n        if position_ids is None:\n            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)\n            block_position_ids = torch.zeros(input_shape[-1], dtype=torch.long, device=device)\n            position_ids = torch.stack((position_ids, block_position_ids), dim=0).unsqueeze(0)\n        if attention_mask is None:\n            # Default to sep = 0 (purely causal attention); keep the tensor on the input device\n            # so the mask construction in GLMStack does not mix CPU and GPU tensors.\n            attention_mask = torch.zeros(batch_size, dtype=torch.long, device=device)\n        # Transformer.\n        transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)\n        last_hidden_states, mems = transformer_output\n        logits = None\n        if self.output_predict:\n            logits = F.linear(last_hidden_states, self.word_embeddings.weight)\n\n        return ModelOutput(\n            last_hidden_states=last_hidden_states,\n            logits=logits,\n            mems=mems,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"GLM Model transformer for multiple choice classification\"\"\",\n    GLM_START_DOCSTRING\n)\nclass GLMForMultipleChoice(GLMPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.glm = GLMModel(config)\n        self.post_init()\n\n    def forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            
attention_mask=None,\n            choice_ids=None,\n            choice_indices=None,\n            labels=None,\n            mems=None,\n            **kwargs\n    ):\n        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)\n        lm_logits = model_output.logits\n        log_probs = []\n        for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):\n            log_probs_single = []\n            for choice, choice_target_id in zip(choices, choice_index):\n                tmp = output[choice_target_id, choice]\n                log_probs_single.append(tmp.sum())\n            log_probs.append(torch.stack(log_probs_single))\n        log_probs = torch.stack(log_probs)\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(log_probs, labels)\n        return ModelOutput(\n            loss=loss,\n            logits=log_probs,\n            lm_logits=lm_logits,\n            mems=model_output.mems\n        )\n\n@add_start_docstrings(\n    \"\"\"GLM Model transformer with a `language modeling` head on top\"\"\",\n    GLM_START_DOCSTRING,\n)\nclass GLMForConditionalGeneration(GLMPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.glm = GLMModel(config)\n        self.post_init()\n\n    def _reorder_cache(self, past, beam_idx):\n        # if decoder past is not included in output\n        # speedy decoding is disabled and no need to reorder\n        if past is None:\n            return past\n        reordered_decoder_past = ()\n        for layer_past_states in past:\n            # get the correct batch idx from layer past batch dim\n            reordered_decoder_past = reordered_decoder_past + (\n                layer_past_states.index_select(0, beam_idx.to(layer_past_states.device)),)\n        return reordered_decoder_past\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, position_ids=None, generation_attention_mask=None,\n                                      **kwargs):\n        # only last token for inputs_ids if past is defined in kwargs\n        attention_mask = generation_attention_mask\n        seq_length = input_ids.shape[1]\n        if past:\n            if position_ids is not None:\n                position_ids = position_ids[:, :, seq_length - 1].unsqueeze(-1)\n            if attention_mask is not None:\n                attention_mask = attention_mask[:, :, seq_length - 1, :seq_length].unsqueeze(-2)\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n        else:\n            if position_ids is not None:\n                position_ids = position_ids[:, :, :seq_length]\n            if attention_mask is not None:\n                attention_mask = attention_mask[:, :, :seq_length, :seq_length]\n        if position_ids is not None and input_ids.size(0) > position_ids.size(0):\n            batch_size = position_ids.size(0)\n            num_beams = input_ids.size(0) // batch_size\n            position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)\n            position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])\n        if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):\n            batch_size = attention_mask.size(0)\n            num_beams = input_ids.size(0) // batch_size\n            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)\n            attention_mask = 
attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])\n        return {\n            \"input_ids\": input_ids,\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"mems\": past,\n        }\n\n    def forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            attention_mask=None,\n            labels=None,\n            mems=None,\n            **kwargs\n    ):\n        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)\n        lm_logits = model_output.logits\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))\n        return ModelOutput(\n            loss=loss,\n            logits=lm_logits,\n            mems=model_output.mems\n        )\n\n\n@add_start_docstrings(\n    \"\"\"GLM Model transformer with a sequence classification/regression head on top (a linear layer on top of\n    the pooled output) e.g. for GLUE tasks. \"\"\",\n    GLM_START_DOCSTRING,\n)\nclass GLMForSequenceClassification(GLMPreTrainedModel):\n    def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):\n        super().__init__(config)\n        self.pool_token = config.pool_token\n        self.glm = GLMModel(config)\n        self.glm.output_predict = False\n        self.num_class = num_class\n        # Multi-choice head.\n        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)\n        classifier_dropout = (\n            config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob\n        )\n        self.dropout = torch.nn.Dropout(classifier_dropout)\n        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        processor_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=SequenceClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(self,\n                input_ids=None,\n                position_ids=None,\n                attention_mask=None,\n                labels=None):\n\n        num_choices = None\n\n        if len(input_ids.shape) == 3:\n            batch_size, num_choices = input_ids.shape[:2]\n            input_ids = input_ids.reshape(-1, input_ids.size(-1))\n            attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])\n            position_ids = position_ids.reshape(-1, *position_ids.size()[2:])\n        model_out = self.glm(input_ids, position_ids, attention_mask)\n        outputs, mems = model_out.last_hidden_states, model_out.mems\n\n        output = outputs[:, 0, :]\n        output = self.dropout(output)\n        output = torch.tanh(self.dense(output))\n        output = self.dropout(output)\n        logits = self.out_proj(output)\n        if num_choices is not None:\n            logits = logits.view(-1, num_choices)\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits, labels)\n        # loss = F.cross_entropy(logits.contiguous().float(), labels.long())\n        return SequenceClassifierOutput(loss=loss,\n 
                                       logits=logits,\n                                        hidden_states=outputs)"
  },
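`GLMForMultipleChoice` above scores each candidate answer by summing the log-probabilities of the candidate's tokens at the positions where they are generated, then treating those sums as per-choice logits for a cross-entropy loss. A minimal sketch of that scoring loop with toy tensors (all shapes, ids and positions below are illustrative, not taken from a checkpoint):

```python
import torch
import torch.nn.functional as F

# Toy setup: batch of 1, sequence of 6 positions, vocabulary of 10 tokens.
lm_logits = torch.randn(1, 6, 10)

# One sample with two candidate choices. For each choice we know its token ids
# and the positions in the sequence where those tokens are predicted.
choice_ids = [[torch.tensor([3, 7]), torch.tensor([5])]]
choice_indices = [[torch.tensor([2, 3]), torch.tensor([2])]]

scores = []
for output, choices, indices in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
    per_choice = []
    for choice, pos in zip(choices, indices):
        # output[pos, choice] gathers log P(token | position) for every token of
        # the choice; the sum is the log-probability of the whole continuation.
        per_choice.append(output[pos, choice].sum())
    scores.append(torch.stack(per_choice))
scores = torch.stack(scores)  # [batch, num_choices]; labels index the correct choice
print(scores.shape)           # torch.Size([1, 2])
```

Because longer choices accumulate more (negative) log-probability terms, this scheme implicitly favors shorter answers; the ZeroShot/FewShot evaluation relies on the candidates being roughly comparable in length.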
  {
    "path": "src/resources/models/glm-10B-chinese/tokenization_glm.py",
    "content": "import os\nfrom typing import Optional, Tuple, List, Union\nfrom shutil import copyfile\nimport torch\n\nfrom transformers import PreTrainedTokenizer, RobertaTokenizer, GPT2Tokenizer, BertTokenizer\nfrom transformers.utils import logging\nfrom transformers.tokenization_utils_base import BatchEncoding\nfrom transformers.models.auto.tokenization_auto import get_tokenizer_config\n# from transformers.utils import torch_required\nfrom transformers.utils.generic import _is_torch_device\nimport sentencepiece as spm\n\nlogger = logging.get_logger(__name__)\n\n\nclass GLMBatchEncoding(BatchEncoding):\n    # @torch_required\n    def to(self, device: Union[str, \"torch.device\"]) -> \"BatchEncoding\":\n        \"\"\"\n        Send all values to device by calling `v.to(device)` (PyTorch only).\n        Args:\n            device (`str` or `torch.device`): The device to put the tensors on.\n        Returns:\n            [`BatchEncoding`]: The same instance after modification.\n        \"\"\"\n\n        # This check catches things like APEX blindly calling \"to\" on all inputs to a module\n        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs\n        # into a HalfTensor\n        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):\n            self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}\n        else:\n            logger.warning(f\"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.\")\n        return self\n\n\nclass GLMTokenizerMixin:\n    @property\n    def sop_token(self) -> Optional[str]:\n        return \"<|startofpiece|>\"\n\n    @property\n    def sop_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.\n        \"\"\"\n        return self.convert_tokens_to_ids(self.sop_token)\n\n    @property\n    def eop_token(self) -> Optional[str]:\n        return \"<|endofpiece|>\"\n\n    @property\n    def eop_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.\n        \"\"\"\n        return self.convert_tokens_to_ids(self.eop_token)\n\n    @property\n    def gmask_token_id(self) -> int:\n        return self.convert_tokens_to_ids(\"[gMASK]\")\n\n    @property\n    def smask_token_id(self) -> int:\n        return self.convert_tokens_to_ids(\"[sMASK]\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]\n\n    def _build_input_for_multiple_choice(self, context, choices):\n        context_id = context[\"input_ids\"]\n        if torch.is_tensor(context_id):\n            context_id = context_id.tolist()\n\n        division = len(context_id)\n        mask_position = context_id.index(self.mask_token_id)\n\n        token = torch.tensor(context_id, dtype=torch.long)\n        attention_mask = [context[\"attention_mask\"].expand(division, -1)]\n        position_id = torch.arange(division, dtype=torch.long)\n        block_position_id = torch.zeros(division, dtype=torch.long)\n\n        choice_ids, choice_indices = [], []\n\n        for choice_str in choices:\n            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],\n                                  dtype=torch.long)\n            
choice_ids.append(choice)\n            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))\n            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))\n\n            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))\n            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))\n            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))\n\n        attention_mask = torch.block_diag(*attention_mask)\n        attention_mask[division:, :division] = context[\"attention_mask\"].unsqueeze(0)\n\n        return {\n            \"input_ids\": token,\n            \"position_ids\": torch.stack((position_id, block_position_id)),\n            \"attention_mask\": attention_mask,\n            \"choice_ids\": choice_ids,\n            \"choice_indices\": choice_indices\n        }\n\n    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):\n        pad_length = max_seq_length - len(tokens)\n        attention_mask = torch.nn.functional.pad(\n            attention_mask,\n            (0, pad_length, 0, pad_length),\n            mode=\"constant\",\n            value=0,\n        )\n        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))\n        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)\n        return tokens, position_ids, attention_mask\n\n    def _collate(self, samples):\n        TILE = 1\n        length_to_pad = (max(map(lambda spl: len(spl[\"input_ids\"]), samples)) + TILE - 1) // TILE * TILE\n\n        token_batch, position_id_batch, attention_mask_batch = [], [], []\n        choices_batch, choice_target_ids_batch = [], []\n\n        for sample in samples:\n            token, position_id, attention_mask = self._pad_batch(\n                sample[\"input_ids\"], sample[\"position_ids\"], sample[\"attention_mask\"], length_to_pad\n            )\n            token_batch.append(token)\n            position_id_batch.append(position_id)\n            attention_mask_batch.append(attention_mask)\n            choices_batch.append(sample[\"choice_ids\"])\n            choice_target_ids_batch.append(sample[\"choice_indices\"])\n        return {\n            \"input_ids\": torch.stack(token_batch),\n            \"position_ids\": torch.stack(position_id_batch),\n            \"attention_mask\": torch.stack(attention_mask_batch).unsqueeze(1),\n            \"choice_ids\": choices_batch,\n            \"choice_indices\": choice_target_ids_batch,\n        }\n\n    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):\n        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input[\"input_ids\"]))]\n        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in\n                   zip(samples, choices)]\n        inputs = self._collate(samples)\n        return GLMBatchEncoding(inputs)\n\n    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):\n        mask_ids = self.mask_token_ids\n        input_ids = model_input.input_ids\n        batch_size, seq_length = input_ids.shape[:2]\n        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]\n        position_ids, block_position_ids = [], 
[]\n        labels = None\n        if targets is not None:\n            is_batched = isinstance(targets, (list, tuple))\n            targets = self(targets, add_special_tokens=False, padding=False).input_ids\n            if not is_batched:\n                targets = [targets]\n            assert len(targets) == len(input_ids)\n            targets = [target[:(max_gen_length-1)] + [self.eop_token_id] for target in targets]\n            if not padding:\n                max_gen_length = max(map(len, targets))\n            targets = [[self.sop_token_id] + target for target in targets]\n            labels = [target[1:] for target in targets]\n            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]\n            labels = [label + [self.pad_token_id] * (max_gen_length - len(label)) for label in labels]\n            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)\n            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)\n            labels = torch.cat((input_ids.new_full((batch_size, seq_length), self.pad_token_id), labels), dim=1)\n        for i in range(batch_size):\n            mask_positions = []\n            for mask_id in mask_ids:\n                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()\n            if not mask_positions:\n                raise ValueError(\"Cannot find mask token in the input\")\n            mask_positions.sort()\n            mask_pos = mask_positions[0]\n            position_ids.append(position_id + [mask_pos] * max_gen_length)\n            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))\n        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)\n        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)\n        position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n        attention_mask = model_input.attention_mask\n        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)\n        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),\n                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],\n                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)\n        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)\n        attention_mask = attention_mask.unsqueeze(1)\n        if targets is None:\n            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)\n        else:\n            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)\n        batch = {\"input_ids\": input_ids, \"position_ids\": position_ids}\n        if labels is None:\n            batch[\"generation_attention_mask\"] = attention_mask\n        else:\n            batch[\"attention_mask\"] = attention_mask\n            batch[\"labels\"] = labels\n        return BatchEncoding(batch)\n\n\nclass GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    @property\n    def gmask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support gMASK\")\n\n    @property\n    def smask_token_id(self) -> int:\n        raise 
NotImplementedError(\"The model doesn't support sMASK\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id]\n\n\nclass GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):\n    vocab_files_names = {\"vocab_file\": \"cog-pretrain.model\"}\n    truncation_side: str = \"left\"\n\n    def __init__(self, vocab_file, **kwargs):\n        super().__init__(**kwargs)\n        self.vocab_file = vocab_file\n        self.sp_model = spm.SentencePieceProcessor()\n        self.sp_model.Load(vocab_file)\n\n    @property\n    def vocab_size(self):\n        return len(self.sp_model)\n\n    def get_vocab(self):\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text, **kwargs):\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return self.sp_model.PieceToId(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        return self.sp_model.IdToPiece(index)\n\n    def convert_tokens_to_string(self, tokens):\n        return self.sp_model.decode(tokens)\n\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + self.vocab_files_names[\"vocab_file\"]\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A BERT sequence has the following format:\n        - single sequence: ``[CLS] X [SEP]``\n        - pair of sequences: ``[CLS] A [SEP] B [SEP]``\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                Optional second list of IDs for sequence pairs.\n        Returns:\n            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.\n        \"\"\"\n        #assert token_ids_1 is None\n        #cls = [self.cls_token_id]\n        #eos = [self.eos_token_id]\n        #return cls + token_ids_0 + eos\n        cls = [self.cls_token_id]\n        eos = [self.eos_token_id]\n        #eop = [self.eop_token_id]\n        #mask = [self.mask_token_id]\n        sep = [self.sep_token_id]\n        #token_ids_0 = cls + token_ids_0 + mask + eos\n        if token_ids_1 is None:\n            return cls + token_ids_0 + eos\n        else:\n            return  cls + token_ids_0 + sep + token_ids_1 + eos\n\n\nclass GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. A BERT sequence has the following format:\n        - single sequence: ``[CLS] X [SEP]``\n        - pair of sequences: ``[CLS] A [SEP] B [SEP]``\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                Optional second list of IDs for sequence pairs.\n        Returns:\n            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.\n        \"\"\"\n        assert token_ids_1 is None\n        cls = [self.cls_token_id]\n        eos = [self.eos_token_id]\n        return cls + token_ids_0 + eos\n\n\nclass GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    @property\n    def gmask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support gMASK\")\n\n    @property\n    def smask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support sMASK\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id]\n\n\nclass GLMTokenizer:\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):\n        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)\n        config_tokenizer_class = tokenizer_config.get(\"tokenizer_class\")\n        if config_tokenizer_class == \"GLMRobertaTokenizer\":\n            tokenizer_class = GLMRobertaTokenizer\n        elif config_tokenizer_class == \"GLMChineseTokenizer\":\n            tokenizer_class = GLMChineseTokenizer\n        elif config_tokenizer_class == \"GLMGPT2Tokenizer\":\n            tokenizer_class = GLMGPT2Tokenizer\n        elif config_tokenizer_class == \"GLMBertTokenizer\":\n            tokenizer_class = GLMBertTokenizer\n        
else:\n            raise NotImplementedError(\"Not implemented tokenizer type:\", config_tokenizer_class)\n        return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)"
  },
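Besides the inference path shown in the model README, `build_inputs_for_generation` has a training path: passing `targets` makes it append `<|startofpiece|>` plus the target tokens, build the full block attention mask, and return `labels` padded with `pad_token_id` over the prompt positions. A minimal sketch of that path, assuming the `sunzeyeah/glm-350M-chinese` checkpoint is available and `trust_remote_code` is enabled (the prompt and target strings are only examples):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sunzeyeah/glm-350M-chinese", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("sunzeyeah/glm-350M-chinese", trust_remote_code=True)

# The prompt must contain a mask token; the target fills the blank.
enc = tokenizer("问题：如何缓解焦虑？回答：" + tokenizer.mask_token, return_tensors="pt")
inputs = tokenizer.build_inputs_for_generation(enc, targets="多运动，保持规律作息。",
                                               max_gen_length=32, padding=True)

# `inputs` now carries input_ids, 2D position_ids, the full attention_mask and
# labels, so the forward pass directly returns a language-modeling loss.
outputs = model(**inputs)
print(outputs.loss)
```

This is the same input layout used during SFT: losses are only computed on the generated span, because the prompt positions of `labels` are set to `pad_token_id`, which is the `ignore_index` of the loss in `GLMForConditionalGeneration`.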
  {
    "path": "src/resources/models/glm-10B-chinese/tokenizer_config.json",
    "content": "{\n  \"name_or_path\": \"BAAI/glm-large-chinese\",\n  \"eos_token\": \"<|endoftext|>\",\n  \"pad_token\": \"<|endoftext|>\",\n  \"cls_token\": \"[CLS]\",\n  \"mask_token\": \"[MASK]\",\n  \"unk_token\": \"[UNK]\",\n  \"sep_token\": \"[SEP]\",\n  \"additional_special_tokens\": [\"<|startofpiece|>\", \"<|endofpiece|>\", \"[gMASK]\", \"[sMASK]\"],\n  \"add_prefix_space\": false,\n  \"tokenizer_class\": \"GLMChineseTokenizer\",\n  \"use_fast\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_glm.GLMChineseTokenizer\",\n      null\n      ]\n  }\n}"
  },
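The `tokenizer_class` and `auto_map` fields above are what make the custom tokenizer loadable without installing a package. A sketch of the two equivalent loading paths (the second one assumes you run from the model directory so that `tokenization_glm.py` is importable):

```python
from transformers import AutoTokenizer

# Path 1: AutoTokenizer follows auto_map in tokenizer_config.json and imports
# GLMChineseTokenizer from the tokenization_glm.py shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("sunzeyeah/glm-10B-chinese", trust_remote_code=True)

# Path 2: the GLMTokenizer dispatcher in tokenization_glm.py reads the same
# "tokenizer_class" field and returns the matching concrete class.
from tokenization_glm import GLMTokenizer
tokenizer = GLMTokenizer.from_pretrained("sunzeyeah/glm-10B-chinese")

print(type(tokenizer).__name__)  # GLMChineseTokenizer
```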
  {
    "path": "src/resources/models/glm-350M-chinese/README.md",
    "content": "---\nlanguage:\n- zh\ntags:\n- glm\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n本仓库由[THUDM/glm-large-chinese](https://huggingface.co/THUDM/glm-large-chinese) fork而来，原仓库实现了PyTorch版本的GLM模型，该模型有3.5亿参数量，模型权重文件以FP32格式存储。\n\n本仓库在原始代码的基础上进行了部分调整，以支持ChatGPT训练pipeline，具体实现可参考：[sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\nThis repository is forked from [THUDM/glm-large-chinese](https://huggingface.co/THUDM/glm-large-chinese) that contains PyTorch implementation of GLM model with 350 million parameters pretrained weights (FP32 precision).\n\nIt is slightly different from the original GLM implementation to support the ChatGPT training pipeline in this github repo: [sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\n---\n\n# Model description\nGLM is a General Language Model pretrained with an autoregressive blank-filling objective and can be finetuned on various natural language understanding and generation tasks.\n\nPlease refer to our paper for a detailed description of GLM:\n\n[GLM: General Language Model Pretraining with Autoregressive Blank Infilling](https://arxiv.org/abs/2103.10360) (ACL 2022)\n\nZhengxiao Du*, Yujie Qian*, Xiao Liu, Ming Ding, Jiezhong Qiu, Zhilin Yang, Jie Tang (*: equal contribution)\n\nFind more examples in our [Github repo](https://github.com/THUDM/GLM).\n\n`glm-10b-chinese` is pretrained on the [WuDaoCorpora](https://www.sciencedirect.com/science/article/pii/S2666651021000152) dataset. It has 48 transformer layers, with hidden size 4096 and 64 attention heads in each layer. The model is pretrained with autoregressive blank filling objectives designed for natural language understanding, seq2seq, and language modeling.\n\n---\n\n# Usage (Text Generation)\n```python\nfrom transformers import AutoTokenizer, AutoModelForSeq2SeqLM\ntokenizer = AutoTokenizer.from_pretrained(\"sunzeyeah/glm-350M-chinese\", trust_remote_code=True)\nmodel = AutoModelForSeq2SeqLM.from_pretrained(\"sunzeyeah/glm-350M-chinese\", trust_remote_code=True)\nmodel = model.half().cuda()\n\nmax_length = 512\nprompt = \"我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点\"\nprefix = \"回答：\"\nencoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\nprompt_length = len(encoded_prompt['input_ids'])\nencoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n                         max_length=min(prompt_length, max_length),\n                         truncation=\"only_first\",\n                         return_tensors=\"pt\",\n                         return_token_type_ids=False)\nmax_gen_length = max_length - encoded_dict['input_ids'].shape[1]\ninputs = tokenizer.build_inputs_for_generation(encoded_dict, max_gen_length=max_gen_length, padding=True)\ninputs = inputs.cuda()\noutputs = model.generate(**inputs,\n                         max_new_tokens=max_gen_length,\n                         eos_token_id=tokenizer.eop_token_id,\n                         pad_token_id=tokenizer.pad_token_id,\n                         do_sample=False,\n                         num_return_sequences=1,\n                         top_p=0.8,\n                         temperature=1.0)\nresults = tokenizer.batch_decode(outputs, skip_special_tokens=True)\nprint(results)\n```\n\n"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/added_tokens.json",
    "content": "{\n  \"<|endoftext|>\": 50000,\n  \"[SEP]\": 50001,\n  \"[CLS]\": 50002,\n  \"[MASK]\": 50003,\n  \"[UNUSED1]\": 50004,\n  \"[UNUSED2]\": 50005,\n  \"<|startofpiece|>\": 50006,\n  \"<|endofpiece|>\": 50007,\n  \"[sMASK]\": 50008,\n  \"[gMASK]\": 50009\n}"
  },
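These ids extend the 50,000-token SentencePiece vocabulary (`vocab_size` is padded to 50048 in config.json, and `pad_token_id` is 50000). A quick sanity check, assuming a local copy of this model directory:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("src/resources/models/glm-350M-chinese",
                                          trust_remote_code=True)

# The ids reported by the tokenizer should match added_tokens.json above.
assert tokenizer.convert_tokens_to_ids("<|startofpiece|>") == 50006  # sop_token_id
assert tokenizer.convert_tokens_to_ids("[gMASK]") == 50009
assert tokenizer.pad_token_id == 50000  # "<|endoftext|>", per config.json
```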
  {
    "path": "src/resources/models/glm-350M-chinese/config.json",
    "content": "{\n  \"model_name_or_path\": \"sunzeyeah/glm-350M-chinese\",\n  \"architectures\": [\n    \"GLMModel\"\n  ],\n  \"attention_dropout_prob\": 0.1,\n  \"attention_scale\": 1.0,\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_glm.GLMConfig\",\n    \"AutoModel\": \"modeling_glm.GLMModel\",\n    \"AutoModelForSeq2SeqLM\": \"modeling_glm.GLMForConditionalGeneration\",\n    \"AutoModelForMultipleChoice\": \"modeling_glm.GLMForMultipleChoice\"\n    },\n  \"block_position_encoding\": true,\n  \"checkpoint_activations\": false,\n  \"checkpoint_num_layers\": 1,\n  \"embedding_dropout_prob\": 0.1,\n  \"hidden_size\": 1024,\n  \"initializer_range\": 0.02,\n  \"max_sequence_length\": 1024,\n  \"model_type\": \"glm\",\n  \"num_attention_heads\": 16,\n  \"num_layers\": 24,\n  \"output_dropout_prob\": 0.1,\n  \"output_predict\": true,\n  \"parallel_output\": true,\n  \"pool_token\": \"cls\",\n  \"relative_encoding\": false,\n  \"spell_func\": \"lstm\",\n  \"spell_length\": null,\n  \"torch_dtype\": \"float32\",\n  \"vocab_size\": 50048,\n  \"pad_token_id\": 50000\n}"
  },
  {
    "path": "src/resources/models/glm-350M-chinese/configuration_glm.py",
    "content": "# coding=utf-8\n# Copyright 2022 shunxing1234 and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" GLM model configuration \"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nGLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"shunxing1234/GLM\": \"https://huggingface.co/shunxing1234/GLM/resolve/main/config.json\",\n    # See all GLM models at https://huggingface.co/models?filter=glm\n}\n\n\nclass GLMConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`~GLMModel`].\n    It is used to instantiate an GLM model according to the specified arguments, defining the model\n    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of\n    the GLM [shunxing1234/GLM-base-cased](https://huggingface.co/shunxing1234/GLM-base-cased) architecture.\n    Configuration objects inherit from  [`PretrainedConfig`] and can be used\n    to control the model outputs. Read the documentation from  [`PretrainedConfig`]\n    for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 30522):\n            Vocabulary size of the GLM model. 
Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`~GLMModel`] or\n            [`~TFGLMModel`].\n        hidden_size (`int`, *optional*, defaults to 768):\n            Dimension of the encoder layers and the pooler layer.\n        num_hidden_layers (`int`, *optional*, defaults to 12):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 12):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        intermediate_size (`int`, *optional*, defaults to 3072):\n            Dimension of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu\"`):\n            The non-linear activation function (function or string) in the encoder and pooler.\n            If string, `\"gelu\"`, `\"relu\"`, `\"selu\"` and `\"gelu_new\"` are supported.\n        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):\n            The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.\n        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):\n            The dropout ratio for the attention probabilities.\n        max_position_embeddings (`int`, *optional*, defaults to 512):\n            The maximum sequence length that this model might ever be used with.\n            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).\n        type_vocab_size (`int`, *optional*, defaults to 2):\n            The vocabulary size of the `token_type_ids` passed when calling [`~GLMModel`] or\n            [`~TFGLMModel`].\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        layer_norm_eps (`float`, *optional*, defaults to 1e-12):\n            The epsilon used by the layer normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). 
Only\n            relevant if `config.is_decoder=True`.\n        Example:\n    ```python\n    >>> from transformers import GLMModel, GLMConfig\n    >>> # Initializing a GLM shunxing1234/GLM-base-cased style configuration\n    >>> configuration = GLMConfig()\n    >>> # Initializing a model from the shunxing1234/GLM-base-cased style configuration\n    >>> model = GLMModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n\"\"\"\n    model_type = \"glm\"\n    attribute_map = {\n        \"num_hidden_layers\": \"num_layers\"\n    }\n\n    def __init__(\n            self,\n            num_layers=24,\n            vocab_size=30592,\n            hidden_size=1024,\n            num_attention_heads=16,\n            embedding_dropout_prob=0.1,\n            attention_dropout_prob=0.1,\n            output_dropout_prob=0.1,\n            max_sequence_length=512,\n            checkpoint_activations=False,\n            checkpoint_num_layers=1,\n            parallel_output=True,\n            relative_encoding=False,\n            block_position_encoding=True,\n            output_predict=False,\n            spell_length=None,\n            spell_func=\"lstm\",\n            attention_scale=1.0,\n            initializer_range=0.02,\n            pool_token=\"cls\",\n            **kwargs\n    ):\n        self.num_layers = num_layers\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_attention_heads = num_attention_heads\n        self.embedding_dropout_prob = embedding_dropout_prob\n        self.attention_dropout_prob = attention_dropout_prob\n        self.output_dropout_prob = output_dropout_prob\n        self.max_sequence_length = max_sequence_length\n        self.checkpoint_activations = checkpoint_activations\n        self.checkpoint_num_layers = checkpoint_num_layers\n        self.parallel_output = parallel_output\n        self.relative_encoding = relative_encoding\n        self.block_position_encoding = block_position_encoding\n        self.output_predict = output_predict\n        self.spell_length = spell_length\n        self.spell_func = spell_func\n        self.attention_scale = attention_scale\n        self.initializer_range = initializer_range\n        self.pool_token = pool_token\n\n        super().__init__(**kwargs)"
  },
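Putting the two files together: the config.json above simply overrides the defaults of `GLMConfig`. A small sketch of the equivalence and of the `attribute_map` alias (run from this model directory, or adjust the import accordingly):

```python
from configuration_glm import GLMConfig

# Values taken from the glm-350M-chinese config.json above; everything else
# keeps the GLMConfig defaults.
config = GLMConfig(
    num_layers=24,
    hidden_size=1024,
    num_attention_heads=16,
    max_sequence_length=1024,
    vocab_size=50048,
    output_predict=True,
    pad_token_id=50000,
)

# attribute_map lets generic transformers code read num_hidden_layers even
# though this config stores the value under num_layers.
assert config.num_hidden_layers == config.num_layers == 24
```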
  {
    "path": "src/resources/models/glm-350M-chinese/modeling_glm.py",
    "content": "# coding=utf-8\n# Copyright 2022 shunxing1234 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch GLM model. \"\"\"\n\nimport math\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\nfrom torch.nn import init, LayerNorm, Linear, CrossEntropyLoss\n\nfrom transformers.activations import gelu\nfrom transformers.utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    ModelOutput,\n)\n\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n)\nfrom .configuration_glm import GLMConfig\nfrom torch.nn.parameter import Parameter\n\n_CHECKPOINT_FOR_DOC = \"shunxing1234/GLM\"\n_CONFIG_FOR_DOC = \"GLMConfig\"\n_TOKENIZER_FOR_DOC = \"GLMTokenizer\"\n\nGLM_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"shunxing1234/GLM\",\n    # See all GLM models at https://huggingface.co/models?filter=glm\n]\n\n\ndef unscaled_init_method(sigma):\n    \"\"\"Init method based on N(0, sigma).\"\"\"\n\n    def init_(tensor):\n        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)\n\n    return init_\n\n\ndef scaled_init_method(mean, std, num_layers):\n    \"\"\"Init method based on N(0, sigma/sqrt(2*num_layers).\"\"\"\n    std = std / math.sqrt(2.0 * num_layers)\n\n    def init_(tensor):\n        return torch.nn.init.normal_(tensor, mean=mean, std=std)\n\n    return init_\n\n\ndef ensure_divisibility(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, '{} is not divisible by {}'.format(\n        numerator, denominator)\n\n\ndef divide(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(tensor, num_partitions,\n                                contiguous_split_chunks=False):\n    \"\"\"Split a tensor along its last dimension.\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # Note: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\nclass MLP(torch.nn.Module):\n    \"\"\"MLP for GPT2.\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, perform gelu transformation, and project 
the\n    state back into h hidden dimension. At the end, dropout is also\n    applied.\n    Arguments:\n        hidden_size: The hidden size of the self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        init_method: initialization method used for the weights. Note\n                     that all biases are initialized to zero and\n                     layernorm weight are initialized to one.\n        output_layer_init_method: output layer initialization. If None,\n                                  use `init_method`.\n    \"\"\"\n\n    def __init__(self, hidden_size, output_dropout_prob, init_method,\n                 output_layer_init_method=None):\n        super(MLP, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n        # Project to 4h.\n        self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size)\n\n        # Project back to h.\n        self.dense_4h_to_h = Linear(\n            4 * hidden_size,\n            hidden_size)\n\n        self.dropout = torch.nn.Dropout(output_dropout_prob)\n\n    def forward(self, hidden_states):\n        # [b, s, 4hp]\n        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n        intermediate_parallel = gelu(intermediate_parallel)\n\n        # [b, s, h]\n        output = self.dense_4h_to_h(intermediate_parallel)\n        output = self.dropout(output)\n        return output\n\n\nclass VocabEmbedding(torch.nn.Module):\n    \"\"\"Embedding parallelized in the vocabulary dimension.\n    This is mainly adapted from torch.nn.Embedding and all the default\n    values are kept.\n    Arguments:\n        num_embeddings: vocabulary size.\n        embedding_dim: size of hidden state.\n        init_method: method to initialize weights.\n    \"\"\"\n\n    def __init__(self, config):\n        super(VocabEmbedding, self).__init__()\n        # Keep the input dimensions.\n        self.num_embeddings = config.vocab_size\n        self.embedding_dim = config.hidden_size\n        # Set the detauls for compatibility.\n        self.padding_idx = None\n        self.max_norm = None\n        self.norm_type = 2.\n        self.scale_grad_by_freq = False\n        self.sparse = False\n        self._weight = None\n\n        self.vocab_start_index = 0\n        self.vocab_end_index = self.num_embeddings\n\n        # Allocate weights.\n        weight = torch.Tensor(self.num_embeddings, self.embedding_dim)\n        self.register_buffer(\"weight\", weight)\n\n        # And initialize.\n        init.xavier_normal_(self.weight)\n\n    def forward(self, input_):\n        # Get the embeddings.\n        output = F.embedding(input_, self.weight,\n                             self.padding_idx, self.max_norm,\n                             self.norm_type, self.scale_grad_by_freq,\n                             self.sparse)\n        return output\n\n\nclass PositionalEmbedding(torch.nn.Module):\n\n    def __init__(self, hidden_size):\n        super(PositionalEmbedding, self).__init__()\n\n        self.hidden_size = hidden_size\n\n        inv_freq = 1 / (10000 ** (torch.arange(0.0, hidden_size, 2.0) / hidden_size))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def forward(self, pos_seq, bsz=None):\n        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)\n        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)\n\n        if 
bsz is not None:\n            return pos_emb[None, :, :].expand(bsz, -1, -1)\n        else:\n            return pos_emb[None, :, :]\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"self-attention layer for GLM.\n    Self-attention layer takes input with size [b, s, h] where b is\n    the batch size, s is the sequence lenght, and h is the hidden size\n    and creates output of the same size.\n    Arguments:\n        hidden_size: total hidden size of the layer (h).\n        num_attention_heads: number of attention heads (n). Note that we\n                             require n to be divisible by number of GPUs\n                             used to parallelize the model. Also, we\n                             require hidden size to be divisible by n.\n        attention_dropout_prob: dropout probability for the attention scores.\n        init_method: weight initialization.\n        output_layer_init_method: output layer initialization. If None, use\n                                  `init_method`.\n    We use the following notation:\n        h: hidden_size\n        n: num_attention_heads\n        p: number of partitions\n        np: n/p\n        hp: h/p\n        hn: h/n\n        b: batch size\n        s: sequence length\n    \"\"\"\n\n    def __init__(self, hidden_size, num_attention_heads,\n                 attention_dropout_prob, output_dropout_prob,\n                 init_method, output_layer_init_method=None,\n                 attention_scale=1.0):\n        super(SelfAttention, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n        # Per attention head and per partition values.\n        self.hidden_size = hidden_size\n        self.hidden_size_per_attention_head = divide(hidden_size,\n                                                     num_attention_heads)\n\n        self.num_attention_heads = num_attention_heads\n        self.attention_scale = attention_scale\n        # Strided linear layer.\n        self.query_key_value = Linear(hidden_size, 3 * hidden_size)\n\n        # Dropout. Note that for a single iteration, this layer will generate\n        # different outputs on different number of parallel partitions but\n        # on average it should not be partition dependent.\n        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)\n\n        # Output.\n        self.dense = Linear(hidden_size,\n                            hidden_size)\n        self.output_dropout = torch.nn.Dropout(output_dropout_prob)\n\n    def _transpose_for_scores(self, tensor):\n        \"\"\"Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with\n        size [b, np, s, hn].\n        \"\"\"\n        new_tensor_shape = tensor.size()[:-1] + \\\n                           (self.num_attention_heads,\n                            self.hidden_size_per_attention_head)\n        tensor = tensor.view(*new_tensor_shape)\n        return tensor.permute(0, 2, 1, 3)\n\n    def forward(self, hidden_states, ltor_mask, mem=None):\n        # hidden_states: [b, s, h]\n        # ltor_mask: [b,1,s,s]\n\n        # Attention heads. 
[b, s, hp]\n        query_length = hidden_states.size(1)\n        # self attention\n        if mem is None:\n            mixed_x_layer = self.query_key_value(hidden_states)\n            (mixed_query_layer,\n             mixed_key_layer,\n             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n        else:\n            cat = torch.cat((mem, hidden_states), 1)\n            mixed_x_layer = self.query_key_value(cat)\n            (mixed_query_layer,\n             mixed_key_layer,\n             mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n            mixed_query_layer = mixed_query_layer[:, -query_length:]\n\n        # Reshape and transpose [b, np, s, hn]\n        query_layer = self._transpose_for_scores(mixed_query_layer)\n        key_layer = self._transpose_for_scores(mixed_key_layer)\n        value_layer = self._transpose_for_scores(mixed_value_layer)\n\n        if self.attention_scale > 1.0:\n            # Raw attention scores. [b, np, s, s]\n            attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_scale),\n                                            key_layer.transpose(-1, -2) / math.sqrt(\n                                                self.hidden_size_per_attention_head * self.attention_scale))\n        else:\n            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2) / math.sqrt(\n                self.hidden_size_per_attention_head))\n\n        # Apply the left to right attention mask.\n        ltor_mask = ltor_mask.type_as(attention_scores)\n        attention_scores = torch.mul(attention_scores, ltor_mask)\n        if self.attention_scale > 1.0:\n            max_attention_scores = attention_scores.max(dim=-1, keepdim=True)[0]\n            attention_scores -= max_attention_scores\n            attention_scores *= self.attention_scale\n\n        attention_scores = attention_scores + (-65504.0) * (1.0 - ltor_mask)\n        # Attention probabilities. [b, np, s, s]\n        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        # with get_cuda_rng_tracker().fork():\n        attention_probs = self.attention_dropout(attention_probs)\n\n        # Context layer.\n        # [b, np, s, hn]\n        context_layer = torch.matmul(attention_probs, value_layer)\n        # [b, s, np, hn]\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + \\\n                                  (self.hidden_size,)\n        # [b, s, hp]\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        # Output. 
[b, s, h]\n        output = self.dense(context_layer)\n        output = self.output_dropout(output)\n\n        return output\n\n\nclass GLMBlock(torch.nn.Module):\n    \"\"\"A single layer transformer for GLM.\n    We use the following notation:\n        h: hidden size\n        n: number of attention heads\n        b: batch size\n        s: sequence length\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    Arguments:\n        hidden_size: The hidden size of the self attention.\n        num_attention_heads: number of attention head in the self\n                             attention.\n        attention_dropout_prob: dropout probability of the attention\n                                score in self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        layernorm_epsilon: epsilon used in layernorm to avoid\n                           division by zero.\n        init_method: initialization method used for the weights. Note\n                     that all biases are initialized to zero and\n                     layernorm weight are initialized to one.\n        output_layer_init_method: output layers (attention output and\n                                  mlp output) initialization. If None,\n                                  use `init_method`.\n    \"\"\"\n\n    def __init__(self,\n                 hidden_size,\n                 num_attention_heads,\n                 attention_dropout_prob,\n                 output_dropout_prob,\n                 layernorm_epsilon,\n                 init_method,\n                 output_layer_init_method=None,\n                 attention_scale=1.0):\n        super(GLMBlock, self).__init__()\n        # Set output layer initialization if not provided.\n        if output_layer_init_method is None:\n            output_layer_init_method = init_method\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)\n\n        # Self attention.\n        self.attention = SelfAttention(\n            hidden_size,\n            num_attention_heads,\n            attention_dropout_prob,\n            output_dropout_prob,\n            init_method,\n            output_layer_init_method=output_layer_init_method,\n            attention_scale=attention_scale)\n\n        # Layernorm on the input data.\n        self.post_attention_layernorm = LayerNorm(hidden_size,\n                                                  eps=layernorm_epsilon)\n\n        # MLP\n        self.mlp = MLP(\n            hidden_size,\n            output_dropout_prob,\n            init_method,\n            output_layer_init_method=output_layer_init_method)\n\n    def forward(self, hidden_states, ltor_mask, mem=None):\n        # hidden_states: [b, s, h]\n        # ltor_mask: [b,1, s,s]\n\n        # Layer norm at the begining of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        mem = self.input_layernorm(mem) if mem is not None else None\n        # Self attention.\n        attention_output = self.attention(layernorm_output, ltor_mask, mem)\n        # Residual connection.\n        layernorm_input = hidden_states + attention_output\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n        # MLP.\n        mlp_output = self.mlp(layernorm_output)\n        # Second residual connection.\n        
output = layernorm_input + mlp_output\n\n        return output\n\n\nclass GLMStack(torch.nn.Module):\n    \"\"\"GLM transformer.\n    This module takes input from embedding layer and it's output can\n    be used directly by a logit layer. It consists of L (num-layers)\n    blocks of:\n        layer norm\n        self attention\n        residual connection\n        layer norm\n        mlp\n        residual connection\n    followed by a final layer norm.\n    Arguments:\n        num_layers: Number of transformer layers.\n        hidden_size: The hidden size of the self attention.\n        num_attention_heads: number of attention head in the self\n                             attention.\n        attention_dropout_prob: dropout probability of the attention\n                                score in self attention.\n        output_dropout_prob: dropout probability for the outputs\n                             after self attention and final output.\n        checkpoint_activations: if True, checkpoint activations.\n        checkpoint_num_layers: number of layers to checkpoint. This\n                               is basically the chunk size in checkpoitning.\n        layernorm_epsilon: epsilon used in layernorm to avoid\n                           division by zero.\n        init_method_std: standard deviation of the init method which has\n                         the form N(0, std).\n        use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)\n                                            scaling for the output weights (\n                                            output of self attention and mlp).\n    \"\"\"\n\n    def __init__(self,\n                 num_layers,\n                 hidden_size,\n                 num_attention_heads,\n                 max_sequence_length,\n                 embedding_dropout_prob,\n                 attention_dropout_prob,\n                 output_dropout_prob,\n                 checkpoint_activations,\n                 checkpoint_num_layers=1,\n                 layernorm_epsilon=1.0e-5,\n                 init_method_std=0.02,\n                 use_scaled_init_for_output_weights=True,\n                 block_position_encoding=False,\n                 attention_scale=1.0,\n                 ):\n        super(GLMStack, self).__init__()\n        self.hidden_size = hidden_size\n        # Store activation checkpoiting flag.\n        self.checkpoint_activations = checkpoint_activations\n        self.checkpoint_num_layers = checkpoint_num_layers\n\n        output_layer_init_method = None\n        if use_scaled_init_for_output_weights:\n            output_layer_init_method = scaled_init_method(0.0, init_method_std,\n                                                          num_layers)\n        # Embeddings dropout\n        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)\n        self.block_position_encoding = block_position_encoding\n\n        # Position embedding (serial).\n        if block_position_encoding:\n            self.position_embeddings = torch.nn.Embedding(max_sequence_length + 1, hidden_size)\n            self.block_position_embeddings = torch.nn.Embedding(max_sequence_length + 1, hidden_size)\n            torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)\n        else:\n            self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)\n        # Initialize the position embeddings.\n        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, 
std=init_method_std)\n\n        def get_layer():\n\n            return GLMBlock(\n                hidden_size,\n                num_attention_heads,\n                attention_dropout_prob,\n                output_dropout_prob,\n                layernorm_epsilon,\n                unscaled_init_method(init_method_std),\n                output_layer_init_method=output_layer_init_method,\n                attention_scale=attention_scale)\n\n        # Transformer layers.\n        self.layers = torch.nn.ModuleList(\n            [get_layer() for _ in range(num_layers)])\n\n        # Final layer norm before output.\n        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)\n\n    def forward(self, hidden_states, position_ids, attention_mask, memory_states=None):\n\n        batch_size, query_length = hidden_states.size()[:2]\n        memory_length = memory_states[0].size(1) if memory_states else 0\n        # attention mask is the beginning postion of B region, \\in [0, query_len)\n        is_scalar = torch.numel(attention_mask) == 1\n        is_sep = is_scalar or torch.numel(attention_mask) == batch_size\n        if is_sep:\n            sep = attention_mask.item() if is_scalar else attention_mask\n\n            # conventional transformer\n            def build_mask_matrix(seq_length, sep, memory_length=0):\n                m = hidden_states.new_ones((1, seq_length, seq_length))\n                m = torch.tril(m)\n                if is_scalar:\n                    m[0, :, :int(sep)] = 1\n                else:\n                    m = m.expand(batch_size, -1, -1)\n                    ids = torch.arange(seq_length, device=sep.device, dtype=sep.dtype).view(1, -1)\n                    mask = ids < sep.view(-1, 1)\n                    m = m.masked_fill(mask.unsqueeze(1).expand_as(m), 1)\n                if memory_length > 0:\n                    m = m.expand(batch_size, -1, -1)\n                    m = torch.cat((hidden_states.new_ones((batch_size, seq_length, memory_length)), m), dim=2)\n                m = m.unsqueeze(1)\n                return m\n\n            attention_mask = build_mask_matrix(query_length, sep, memory_length=memory_length)\n        else:\n            if attention_mask.dim() == 2:\n                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)\n            attention_mask = attention_mask[:, :, :, -query_length - memory_length:]\n\n        if self.block_position_encoding:\n            position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]\n        position_embeddings = self.position_embeddings(position_ids)\n\n        hidden_states = hidden_states + position_embeddings\n        if self.block_position_encoding:\n            block_position_embeddings = self.block_position_embeddings(block_position_ids)\n            hidden_states = hidden_states + block_position_embeddings\n        hidden_states = self.embedding_dropout(hidden_states)\n\n        def check_detach(_hidden_states):\n            return _hidden_states.detach()\n\n        mem_layers = [check_detach(hidden_states)]\n\n        for i, layer in enumerate(self.layers):\n\n            args = [hidden_states, attention_mask]\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    # None for past_key_value\n                    return module(*inputs)\n\n                return custom_forward\n\n            mem_i = memory_states[i] if memory_states else None\n\n            if self.checkpoint_activations:\n                
hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer),\n                    hidden_states,\n                    mem=mem_i,\n                )\n            else:\n                hidden_states = layer(*args, mem=mem_i)\n            mem_layers.append(check_detach(hidden_states))\n\n        # Final layer norm.\n        output = self.final_layernorm(hidden_states)\n        mem_layers = self.update_mems(mem_layers, memory_states)\n        return (output, mem_layers)\n\n    def update_mems(self, hiddens, mems):\n        memory_length = mems[0].size(1) if mems else 0\n        query_length = hiddens[0].size(1)\n        new_memory_length = memory_length + query_length\n\n        new_mems = []\n        # with torch.no_grad():\n        for i in range(len(hiddens)):\n            if new_memory_length <= query_length:\n                new_mems.append(hiddens[i][:, -new_memory_length:])\n            else:\n                new_mems.append(torch.cat((mems[i][:, -new_memory_length + query_length:], hiddens[i]), dim=1))\n        return new_mems\n\n\nclass GLMPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    config_class = GLMConfig\n    base_model_prefix = \"glm\"\n    supports_gradient_checkpointing = True\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\" Initialize the weights \"\"\"\n        if isinstance(module, torch.nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, torch.nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, torch.nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GLMModel):\n            module.gradient_checkpointing = value\n\n\nGLM_START_DOCSTRING = r\"\"\"\n    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general\n    usage and behavior.\n    Parameters:\n        config ([`~GLMConfig`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the configuration.\n            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nGLM_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `({0})`):\n            Indices of input sequence tokens in the vocabulary.\n            Indices can be obtained using [`GLMTokenizer`].\n            See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.FloatTensor` of 
shape `({0})`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n            [What are attention masks?](../glossary#attention-mask)\n        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:\n            - 0 corresponds to a *sentence A* token,\n            - 1 corresponds to a *sentence B* token.\n            [What are token type IDs?](../glossary#token-type-ids)\n        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings.\n            Selected in the range `[0, config.max_position_embeddings - 1]`.\n            [What are position IDs?](../glossary#position-ids)\n        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert *input_ids* indices into associated vectors\n            than the model's internal embedding lookup matrix.\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare GLM Model transformer outputting raw hidden-states without any specific head on top.\",\n    GLM_START_DOCSTRING,\n)\nclass GLMModel(GLMPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well\n    as a decoder, in which case a layer of cross-attention is added between\n    the self-attention layers, following the architecture described in [Attention is\n    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,\n    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.\n    To behave as a decoder the model needs to be initialized with the\n    `is_decoder` argument of the configuration set to `True`.\n    To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`\n    argument and `add_cross_attention` set to `True`; an\n    `encoder_hidden_states` is then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.config = config\n        self.output_predict = config.output_predict\n        # Word embeddings (parallel).\n        self.word_embeddings = VocabEmbedding(config)\n\n        # Transformer\n        self.transformer = GLMStack(config.num_layers,\n                                    config.hidden_size,\n                                    config.num_attention_heads,\n                                    config.max_sequence_length,\n                                    config.embedding_dropout_prob,\n                                    config.attention_dropout_prob,\n                                    config.output_dropout_prob,\n                                    config.checkpoint_activations,\n                                    config.checkpoint_num_layers,\n                                    attention_scale=config.attention_scale,\n                                    block_position_encoding=config.block_position_encoding)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        processor_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutputWithPastAndCrossAttentions,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            attention_mask=None,\n            mems=None,\n            **kwargs\n    ):\n        batch_size = input_ids.size(0)\n        words_embeddings = self.word_embeddings(input_ids)\n        embeddings = words_embeddings\n\n        device = input_ids.device\n        input_shape = input_ids.size()\n\n        if position_ids is None:\n            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)\n            block_position_ids = torch.zeros(input_shape[-1], dtype=torch.long, device=device)\n            position_ids = torch.stack((position_ids, block_position_ids), dim=0).unsqueeze(0)\n        if attention_mask is None:\n            # Default to sep=0, i.e. a fully causal mask; keep dtype/device consistent with input_ids\n            attention_mask = torch.zeros(batch_size, dtype=torch.long, device=device)\n        # Transformer.\n        transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)\n        logits, hidden_layers = transformer_output\n        # outputs = hidden_layers\n        if self.output_predict:\n            # Parallel logits.\n            # logits_parallel = mpu.copy_to_model_parallel_region(\n            #     logits)\n            logits = F.linear(logits, self.word_embeddings.weight)\n\n        return ModelOutput(\n            logits=logits,\n            mems=hidden_layers,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"GLM Model transformer for multiple choice classification\"\"\",\n    GLM_START_DOCSTRING\n)\nclass GLMForMultipleChoice(GLMPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.glm = GLMModel(config)\n        self.post_init()\n\n    def forward(\n            
self,\n            input_ids=None,\n            position_ids=None,\n            attention_mask=None,\n            choice_ids=None,\n            choice_indices=None,\n            labels=None,\n            mems=None,\n            **kwargs\n    ):\n        model_output = self.glm.forward(input_ids, position_ids, attention_mask, mems=mems, **kwargs)\n        lm_logits = model_output.logits\n        log_probs = []\n        for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):\n            log_probs_single = []\n            for choice, choice_target_id in zip(choices, choice_index):\n                tmp = output[choice_target_id, choice]\n                log_probs_single.append(tmp.sum())\n            log_probs.append(torch.stack(log_probs_single))\n        log_probs = torch.stack(log_probs)\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(log_probs, labels)\n        return ModelOutput(\n            loss=loss,\n            logits=log_probs,\n            lm_logits=lm_logits,\n            mems=model_output.mems\n        )\n\n@add_start_docstrings(\n    \"\"\"GLM Model transformer with a `language modeling` head on top\"\"\",\n    GLM_START_DOCSTRING,\n)\nclass GLMForConditionalGeneration(GLMPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.glm = GLMModel(config)\n        self.post_init()\n\n    def _reorder_cache(self, past, beam_idx):\n        # if decoder past is not included in output\n        # speedy decoding is disabled and no need to reorder\n        if past is None:\n            return past\n        reordered_decoder_past = ()\n        for layer_past_states in past:\n            # get the correct batch idx from layer past batch dim\n            reordered_decoder_past = reordered_decoder_past + (\n                layer_past_states.index_select(0, beam_idx.to(layer_past_states.device)),)\n        return reordered_decoder_past\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, position_ids=None, generation_attention_mask=None,\n                                      **kwargs):\n        # only last token for inputs_ids if past is defined in kwargs\n        attention_mask = generation_attention_mask\n        seq_length = input_ids.shape[1]\n        if past:\n            if position_ids is not None:\n                position_ids = position_ids[:, :, seq_length - 1].unsqueeze(-1)\n            if attention_mask is not None:\n                attention_mask = attention_mask[:, :, seq_length - 1, :seq_length].unsqueeze(-2)\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n        else:\n            if position_ids is not None:\n                position_ids = position_ids[:, :, :seq_length]\n            if attention_mask is not None:\n                attention_mask = attention_mask[:, :, :seq_length, :seq_length]\n        return {\n            \"input_ids\": input_ids,\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"mems\": past,\n        }\n\n    def forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            attention_mask=None,\n            labels=None,\n            mems=None,\n            **kwargs\n    ):\n        model_output = self.glm.forward(input_ids, position_ids, attention_mask, mems=mems, **kwargs)\n        lm_logits = model_output.logits\n        loss = None\n        if labels is not None:\n  
          loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))\n        return ModelOutput(\n            loss=loss,\n            logits=lm_logits,\n            mems=model_output.mems\n        )"
  },
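For reference, here is a minimal usage sketch for the `GLMForConditionalGeneration` head defined above. It assumes a downloaded `glm-350M-chinese` checkpoint whose `config.json` `auto_map` registers these classes for `AutoModelForSeq2SeqLM` under `trust_remote_code=True` (as upstream GLM checkpoints do); the prompt and generation lengths are illustrative.

```python
# Minimal sketch, not an official inference script.
# Assumption: MODEL_PATH points at a downloaded glm-350M-chinese checkpoint
# whose auto_map resolves to the classes defined in modeling_glm.py above.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_PATH = "sunzeyeah/glm-350M-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
model.eval()

# GLM answers by autoregressively filling the [gMASK] span.
inputs = tokenizer("问题：中国的首都是哪座城市？回答：[gMASK]", return_tensors="pt")
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=64)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64,
                             eos_token_id=tokenizer.eop_token_id)
print(tokenizer.decode(outputs[0].tolist()))
```

`build_inputs_for_generation` (defined in `tokenization_glm.py` below) appends `<|startofpiece|>` and builds the 2D position ids and `generation_attention_mask` that `prepare_inputs_for_generation` above expects.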
  {
    "path": "src/resources/models/glm-350M-chinese/tokenization_glm.py",
    "content": "import os\nfrom typing import Optional, Tuple, List, Union\nfrom shutil import copyfile\nimport torch\n\nfrom transformers import PreTrainedTokenizer, RobertaTokenizer, GPT2Tokenizer, BertTokenizer\nfrom transformers.utils import logging\nfrom transformers.tokenization_utils_base import BatchEncoding\nfrom transformers.models.auto.tokenization_auto import get_tokenizer_config\n# from transformers.utils import torch_required\nfrom transformers.utils.generic import _is_torch_device\nimport sentencepiece as spm\n\nlogger = logging.get_logger(__name__)\n\n\nclass GLMBatchEncoding(BatchEncoding):\n    # @torch_required\n    def to(self, device: Union[str, \"torch.device\"]) -> \"BatchEncoding\":\n        \"\"\"\n        Send all values to device by calling `v.to(device)` (PyTorch only).\n        Args:\n            device (`str` or `torch.device`): The device to put the tensors on.\n        Returns:\n            [`BatchEncoding`]: The same instance after modification.\n        \"\"\"\n\n        # This check catches things like APEX blindly calling \"to\" on all inputs to a module\n        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs\n        # into a HalfTensor\n        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):\n            self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}\n        else:\n            logger.warning(f\"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.\")\n        return self\n\n\nclass GLMTokenizerMixin:\n    @property\n    def sop_token(self) -> Optional[str]:\n        return \"<|startofpiece|>\"\n\n    @property\n    def sop_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.\n        \"\"\"\n        return self.convert_tokens_to_ids(self.sop_token)\n\n    @property\n    def eop_token(self) -> Optional[str]:\n        return \"<|endofpiece|>\"\n\n    @property\n    def eop_token_id(self) -> Optional[int]:\n        \"\"\"\n        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.\n        \"\"\"\n        return self.convert_tokens_to_ids(self.eop_token)\n\n    @property\n    def gmask_token_id(self) -> int:\n        return self.convert_tokens_to_ids(\"[gMASK]\")\n\n    @property\n    def smask_token_id(self) -> int:\n        return self.convert_tokens_to_ids(\"[sMASK]\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]\n\n    def _build_input_for_multiple_choice(self, context, choices):\n        context_id = context[\"input_ids\"]\n        if torch.is_tensor(context_id):\n            context_id = context_id.tolist()\n\n        division = len(context_id)\n        mask_position = context_id.index(self.mask_token_id)\n\n        token = torch.tensor(context_id, dtype=torch.long)\n        attention_mask = [context[\"attention_mask\"].expand(division, -1)]\n        position_id = torch.arange(division, dtype=torch.long)\n        block_position_id = torch.zeros(division, dtype=torch.long)\n\n        choice_ids, choice_indices = [], []\n\n        for choice_str in choices:\n            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],\n                                  dtype=torch.long)\n            
choice_ids.append(choice)\n            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))\n            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))\n\n            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))\n            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))\n            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))\n\n        attention_mask = torch.block_diag(*attention_mask)\n        attention_mask[division:, :division] = context[\"attention_mask\"].unsqueeze(0)\n\n        return {\n            \"input_ids\": token,\n            \"position_ids\": torch.stack((position_id, block_position_id)),\n            \"attention_mask\": attention_mask,\n            \"choice_ids\": choice_ids,\n            \"choice_indices\": choice_indices\n        }\n\n    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):\n        pad_length = max_seq_length - len(tokens)\n        attention_mask = torch.nn.functional.pad(\n            attention_mask,\n            (0, pad_length, 0, pad_length),\n            mode=\"constant\",\n            value=0,\n        )\n        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))\n        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)\n        return tokens, position_ids, attention_mask\n\n    def _collate(self, samples):\n        TILE = 1\n        length_to_pad = (max(map(lambda spl: len(spl[\"input_ids\"]), samples)) + TILE - 1) // TILE * TILE\n\n        token_batch, position_id_batch, attention_mask_batch = [], [], []\n        choices_batch, choice_target_ids_batch = [], []\n\n        for sample in samples:\n            token, position_id, attention_mask = self._pad_batch(\n                sample[\"input_ids\"], sample[\"position_ids\"], sample[\"attention_mask\"], length_to_pad\n            )\n            token_batch.append(token)\n            position_id_batch.append(position_id)\n            attention_mask_batch.append(attention_mask)\n            choices_batch.append(sample[\"choice_ids\"])\n            choice_target_ids_batch.append(sample[\"choice_indices\"])\n        return {\n            \"input_ids\": torch.stack(token_batch),\n            \"position_ids\": torch.stack(position_id_batch),\n            \"attention_mask\": torch.stack(attention_mask_batch).unsqueeze(1),\n            \"choice_ids\": choices_batch,\n            \"choice_indices\": choice_target_ids_batch,\n        }\n\n    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):\n        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input[\"input_ids\"]))]\n        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in\n                   zip(samples, choices)]\n        inputs = self._collate(samples)\n        return GLMBatchEncoding(inputs)\n\n    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None,\n                                    padding=False):\n        mask_ids = self.mask_token_ids\n        input_ids = model_input.input_ids\n        batch_size, seq_length = input_ids.shape[:2]\n        labels = None\n        # create labels\n        if targets is not None:\n            is_batched = 
isinstance(targets, (list, tuple))\n            targets = self(targets, add_special_tokens=False, padding=False).input_ids\n            if not is_batched:\n                targets = [targets]\n            assert len(targets) == len(input_ids)\n            targets = [target[:(max_gen_length-1)] + [self.eop_token_id] for target in targets]\n            if not padding:\n                max_gen_length = max(map(len, targets))\n            targets = [[self.sop_token_id] + target for target in targets]\n            labels = [target[1:] for target in targets]\n            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]\n            labels = [label + [self.pad_token_id] * (max_gen_length - len(label)) for label in labels]\n            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)\n            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)\n            labels = torch.cat((input_ids.new_full((batch_size, seq_length), self.pad_token_id), labels), dim=1)\n        # create position ids\n        position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]\n        position_ids, block_position_ids = [], []\n        for i in range(batch_size):\n            mask_positions = []\n            for mask_id in mask_ids:\n                mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()\n            if not mask_positions:\n                raise ValueError(\"Cannot find mask token in the input\")\n            mask_positions.sort()\n            mask_pos = mask_positions[0]\n            position_ids.append(position_id + [mask_pos] * max_gen_length)\n            block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))\n        position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)\n        block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)\n        position_ids = torch.stack((position_ids, block_position_ids), dim=1)\n        # create attention mask\n        attention_mask = model_input.attention_mask\n        attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)\n        generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),\n                                               torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],\n                                              dim=0).unsqueeze(0).expand(batch_size, -1, -1)\n        attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)\n        attention_mask = attention_mask.unsqueeze(1)\n        # create input ids\n        if targets is None:\n            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)\n        else:\n            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)\n\n        batch = {\"input_ids\": input_ids, \"position_ids\": position_ids}\n        if labels is None:\n            batch[\"generation_attention_mask\"] = attention_mask\n        else:\n            batch[\"attention_mask\"] = attention_mask\n            batch[\"labels\"] = labels\n\n        return BatchEncoding(batch)\n\n\nclass GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    @property\n    def 
gmask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support gMASK\")\n\n    @property\n    def smask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support sMASK\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id]\n\n\nclass GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):\n    vocab_files_names = {\"vocab_file\": \"cog-pretrain.model\"}\n    truncation_side: str = \"left\"\n\n    def __init__(self, vocab_file, **kwargs):\n        super().__init__(**kwargs)\n        self.vocab_file = vocab_file\n        self.sp_model = spm.SentencePieceProcessor()\n        self.sp_model.Load(vocab_file)\n\n    @property\n    def vocab_size(self):\n        return len(self.sp_model)\n\n    def get_vocab(self):\n        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text, **kwargs):\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) to an id using the vocab.\"\"\"\n        return self.sp_model.PieceToId(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) to a token (str) using the vocab.\"\"\"\n        return self.sp_model.IdToPiece(index)\n\n    def convert_tokens_to_string(self, tokens):\n        return self.sp_model.decode(tokens)\n\n    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(\n            save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + self.vocab_files_names[\"vocab_file\"]\n        )\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. 
A BERT sequence has the following format:\n        - single sequence: ``[CLS] X [SEP]``\n        - pair of sequences: ``[CLS] A [SEP] B [SEP]``\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                Optional second list of IDs for sequence pairs.\n        Returns:\n            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.\n        \"\"\"\n        #assert token_ids_1 is None\n        #cls = [self.cls_token_id]\n        #eos = [self.eos_token_id]\n        #return cls + token_ids_0 + eos\n        cls = [self.cls_token_id]\n        eos = [self.eos_token_id]\n        #eop = [self.eop_token_id]\n        #mask = [self.mask_token_id]\n        sep = [self.sep_token_id]\n        #token_ids_0 = cls + token_ids_0 + mask + eos\n        if token_ids_1 is None:\n            return cls + token_ids_0 + eos\n        else:\n            return cls + token_ids_0 + sep + token_ids_1 + eos\n\n\nclass GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    def build_inputs_with_special_tokens(\n            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n    ) -> List[int]:\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. A BERT sequence has the following format:\n        - single sequence: ``[CLS] X [SEP]``\n        - pair of sequences: ``[CLS] A [SEP] B [SEP]``\n        Args:\n            token_ids_0 (:obj:`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (:obj:`List[int]`, `optional`):\n                Optional second list of IDs for sequence pairs.\n        Returns:\n            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.\n        \"\"\"\n        assert token_ids_1 is None\n        cls = [self.cls_token_id]\n        eos = [self.eos_token_id]\n        return cls + token_ids_0 + eos\n\n\nclass GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):\n    model_input_names = [\"input_ids\", \"position_ids\", \"attention_mask\"]\n    truncation_side: str = \"left\"\n\n    @property\n    def gmask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support gMASK\")\n\n    @property\n    def smask_token_id(self) -> int:\n        raise NotImplementedError(\"The model doesn't support sMASK\")\n\n    @property\n    def mask_token_ids(self):\n        return [self.mask_token_id]\n\n\nclass GLMTokenizer:\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):\n        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)\n        config_tokenizer_class = tokenizer_config.get(\"tokenizer_class\")\n        if config_tokenizer_class == \"GLMRobertaTokenizer\":\n            tokenizer_class = GLMRobertaTokenizer\n        elif config_tokenizer_class == \"GLMChineseTokenizer\":\n            tokenizer_class = GLMChineseTokenizer\n        elif config_tokenizer_class == \"GLMGPT2Tokenizer\":\n            tokenizer_class = GLMGPT2Tokenizer\n        elif config_tokenizer_class == \"GLMBertTokenizer\":\n            tokenizer_class = GLMBertTokenizer\n        
else:\n            raise NotImplementedError(f\"Not implemented tokenizer type: {config_tokenizer_class}\")\n        return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)"
  },
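A companion sketch for the multiple-choice path: `build_inputs_for_multiple_choice` packs each candidate answer behind the context with a block-diagonal attention mask, producing exactly the `choice_ids`/`choice_indices` that `GLMForMultipleChoice.forward` consumes. The checkpoint path, context, and choices are illustrative.

```python
# Sketch of preparing multiple-choice inputs (assumption: same local
# glm-350M-chinese checkpoint as above; context/choices are made up).
import torch
from transformers import AutoTokenizer

MODEL_PATH = "sunzeyeah/glm-350M-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Each context must contain exactly one mask token; every choice is scored
# as an infill at that position.
encoded = tokenizer(["北京是中国的[MASK]。"], return_tensors="pt", padding=True)
batch = tokenizer.build_inputs_for_multiple_choice(encoded, [["首都", "省会"]])

for key, value in batch.items():
    # input_ids / 2D position_ids / block-diagonal attention_mask are tensors;
    # choice_ids and choice_indices are per-sample lists of tensors.
    print(key, value.shape if torch.is_tensor(value) else [t.shape for t in value[0]])
```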
  {
    "path": "src/resources/models/glm-350M-chinese/tokenizer_config.json",
    "content": "{\n  \"name_or_path\": \"BAAI/glm-large-chinese\",\n  \"eos_token\": \"<|endoftext|>\",\n  \"pad_token\": \"<|endoftext|>\",\n  \"cls_token\": \"[CLS]\",\n  \"mask_token\": \"[MASK]\",\n  \"unk_token\": \"[UNK]\",\n  \"sep_token\": \"[SEP]\",\n  \"additional_special_tokens\": [\"<|startofpiece|>\", \"<|endofpiece|>\", \"[gMASK]\", \"[sMASK]\"],\n  \"add_prefix_space\": false,\n  \"tokenizer_class\": \"GLMChineseTokenizer\",\n  \"use_fast\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_glm.GLMChineseTokenizer\",\n      null\n      ]\n  }\n}"
  },
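The `tokenizer_class` and `auto_map` fields above are what route `AutoTokenizer` to the custom `GLMChineseTokenizer`. A quick check (sketch; the printed ids depend on the shipped `cog-pretrain.model`, so treat them as illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sunzeyeah/glm-350M-chinese", trust_remote_code=True)
print(type(tok).__name__)                # GLMChineseTokenizer, via auto_map
print(tok.sop_token, tok.sop_token_id)   # <|startofpiece|>
print(tok.eop_token, tok.eop_token_id)   # <|endofpiece|>
print(tok.mask_token_ids)                # ids of [MASK], [sMASK], [gMASK]
```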
  {
    "path": "src/resources/models/pangu-13B/config.json",
    "content": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-13B\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguForCausalLM\"\n  ],\n  \"attn_pdrop\": 0.1,\n  \"embd_pdrop\": 0.1,\n  \"hidden_size\": 5120,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": null,\n  \"layer_norm_epsilon\": 1e-05,\n  \"max_position_embeddings\": 1024,\n  \"model_type\": \"pangu\",\n  \"num_heads\": 40,\n  \"num_layers\": 40,\n  \"resid_pdrop\": 0.1,\n  \"scale_attn_weights\": true,\n  \"summary_activation\": null,\n  \"summary_first_dropout\": 0.1,\n  \"summary_proj_to_labels\": true,\n  \"summary_type\": \"cls_index\",\n  \"summary_use_proj\": true,\n  \"torch_dtype\": \"float32\",\n  \"use_cache\": true,\n  \"vocab_size\": 40064,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_gptpangu.GPTPanguConfig\",\n    \"AutoTokenizer\": [\"tokenization_gptpangu.GPTPanguTokenizer\", null],\n    \"AutoModelForCausalLM\": \"modeling_gptpangu.GPTPanguForCausalLM\"\n  },\n  \"pad_token_id\": 6\n}\n"
  },
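As a sanity check, the hyperparameters above do land at the nominal 13B scale. Here is a rough, illustrative count using the module shapes from `modeling_gptpangu.py` below (biases and LayerNorm weights ignored; the word embedding is counted once because `transformers` ties `lm_head` to `wte` when `tie_word_embeddings` is left unset, as it is here):

```python
# Back-of-envelope parameter count for the pangu-13B config above (sketch).
h, n_layer, vocab, n_pos = 5120, 40, 40064, 1024

embeddings = vocab * h + 2 * n_pos * h   # wte + wpe + wqe (query embedding)
attn = 4 * h * h                         # q/k/v projections + c_proj
mlp = 2 * h * (4 * h)                    # c_fc and c_proj, inner dim 4 * h
total = embeddings + n_layer * (attn + mlp)
print(f"~{total / 1e9:.1f}B parameters")  # ~12.8B, i.e. the nominal 13B
```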
  {
    "path": "src/resources/models/pangu-13B/configuration_gptpangu.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type = \"gpt_pangu\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=40000,\n        max_position_embeddings=1024,\n        hidden_size=5120,\n        intermediate_size=None,\n        num_layers=40,\n        num_heads=40,\n        activation_function=\"gelu\",\n        resid_pdrop=0.1,\n        embd_pdrop=0.1,\n        attn_pdrop=0.1,\n        layer_norm_epsilon=1e-5,\n        scale_attn_weights=True,\n        initializer_range=0.02,\n        summary_type=\"cls_index\",\n        summary_use_proj=True,\n        summary_activation=None,\n        summary_proj_to_labels=True,\n        summary_first_dropout=0.1,\n        use_cache=True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.activation_function = activation_function\n        self.resid_pdrop = resid_pdrop\n        self.embd_pdrop = embd_pdrop\n        self.attn_pdrop = attn_pdrop\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.scale_attn_weights = scale_attn_weights\n        self.initializer_range = initializer_range\n        self.summary_type = summary_type\n        self.summary_use_proj = summary_use_proj\n        self.summary_activation = summary_activation\n        self.summary_first_dropout = summary_first_dropout\n        self.summary_proj_to_labels = summary_proj_to_labels\n        self.use_cache = use_cache\n\n        super().__init__(**kwargs)\n"
  },
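Note that `intermediate_size` defaults to `None`; the model code in `modeling_gptpangu.py` (next file) resolves that to `4 * hidden_size`. A quick illustration, assuming the checkpoint directory is importable:

```python
# Illustrative check of how the None default is resolved (mirrors the
# inner_dim logic in GPTPanguBlock.__init__ below).
from configuration_gptpangu import GPTPanguConfig

config = GPTPanguConfig()  # hidden_size=5120, intermediate_size=None
inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
print(config.hidden_size, inner_dim)  # 5120 20480
```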
  {
    "path": "src/resources/models/pangu-13B/modeling_gptpangu.py",
    "content": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\n\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass GPTPanguAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(\n                1, 1, max_positions, max_positions\n            ),\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e4))\n\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.scale_attn_weights = config.scale_attn_weights\n\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n\n        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n        self.resid_dropout = nn.Dropout(config.resid_pdrop)\n\n\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        if self.scale_attn_weights:\n            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)\n\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()\n        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n\n        return attn_output, attn_weights\n\n    def _split_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Splits hidden_size dim into attn_head_size and num_heads\n        \"\"\"\n        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)\n        tensor = tensor.view(*new_shape)\n        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)\n\n    def _merge_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Merges attn_head_size dim and num_attn_heads dim into hidden_size\n        \"\"\"\n        tensor = tensor.permute(0, 
2, 1, 3).contiguous()\n        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)\n        return tensor.view(new_shape)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        query = self.q_proj(custom_query) if custom_query is not None else self.q_proj(hidden_states)\n        key = self.k_proj(hidden_states)\n        value = self.v_proj(hidden_states)\n\n        query = self._split_heads(query, self.num_heads, self.head_dim)\n        key = self._split_heads(key, self.num_heads, self.head_dim)\n        value = self._split_heads(value, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), dim=-2)\n\n        if use_cache is True:\n            present = (key, value)\n        else:\n            present = None\n\n        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)\n\n        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)\n        attn_output = self.c_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs  # a, present, (attentions)\n\n\nclass GPTPanguMLP(nn.Module):\n    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * hidden_size\n        super().__init__()\n        embed_dim = config.hidden_size\n        self.c_fc = nn.Linear(embed_dim, intermediate_size)\n        self.c_proj = nn.Linear(intermediate_size, embed_dim)\n        self.act = ACT2FN[config.activation_function]\n        self.dropout = nn.Dropout(config.resid_pdrop)\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        return hidden_states\n\n\nclass GPTPanguBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        hidden_size = config.hidden_size\n        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size\n\n        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.attn = GPTPanguAttention(config)\n        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.mlp = GPTPanguMLP(inner_dim, config)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        residual = hidden_states\n        hidden_states = self.ln_1(hidden_states)\n        attn_outputs = self.attn(\n            hidden_states,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            custom_query=custom_query,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)\n        outputs = attn_outputs[1:]\n        # residual connection\n        
hidden_states = attn_output + residual\n\n        residual = hidden_states\n        hidden_states = self.ln_2(hidden_states)\n        feed_forward_hidden_states = self.mlp(hidden_states)\n        # residual connection\n        hidden_states = residual + feed_forward_hidden_states\n\n        if use_cache:\n            outputs = (hidden_states,) + outputs\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        return outputs  # hidden_states, present, (attentions, cross_attentions)\n\n\nclass GPTPanguPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    # config_class = GPTPanguConfig\n    base_model_prefix = \"transformer\"\n    supports_gradient_checkpointing = True\n\n    def __init__(self, *inputs, **kwargs):\n        super().__init__(*inputs, **kwargs)\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if isinstance(module, (nn.Linear,)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if \"c_proj\" in name and \"weight\" in name:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_layers)))\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GPTPanguModel):\n            module.gradient_checkpointing = value\n\n\nclass GPTPanguModel(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embed_dim = config.hidden_size\n\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)\n        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n        self.wqe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n\n        self.drop = nn.Dropout(config.embd_pdrop)\n        self.h = nn.ModuleList([GPTPanguBlock(config) for _ in range(config.num_layers)])\n        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.wte\n\n    def set_input_embeddings(self, new_embeddings):\n        self.wte = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n            batch_size = input_ids.shape[0]\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size = inputs_embeds.shape[0]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if token_type_ids is not None:\n            token_type_ids = token_type_ids.view(-1, input_shape[-1])\n        if position_ids is not None:\n            position_ids = position_ids.view(-1, input_shape[-1])\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = 
tuple([None] * len(self.h))\n        else:\n            past_length = past_key_values[0][0].size(-2)\n        if position_ids is None:\n            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)\n            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n\n        # GPT2Attention mask.\n        if attention_mask is not None:\n            if batch_size <= 0:\n                raise ValueError(\"batch_size has to be defined and > 0\")\n            attention_mask = attention_mask.view(batch_size, -1)\n            # We create a 3D attention mask from a 2D tensor mask.\n            # Sizes are [batch_size, 1, 1, to_seq_length]\n            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]\n            # this attention mask is more simple than the triangular masking of causal attention\n            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.\n            attention_mask = attention_mask[:, None, None, :]\n\n            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n            # masked positions, this operation will create a tensor which is 0.0 for\n            # positions we want to attend and -10000.0 for masked positions.\n            # Since we are adding it to the raw scores before the softmax, this is\n            # effectively the same as removing these entirely.\n            if attention_mask.dtype == torch.int or attention_mask.dtype == torch.bool:\n                attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n                attention_mask = (1.0 - attention_mask) * -10000.0\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x num_heads x N x N\n        # head_mask has shape n_layer x batch x num_heads x N x N\n        head_mask = self.get_head_mask(head_mask, self.config.num_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.wte(input_ids)\n        position_embeds = self.wpe(position_ids)\n        hidden_states = inputs_embeds + position_embeds\n\n        if token_type_ids is not None:\n            token_type_embeds = self.wte(token_type_ids)\n            hidden_states = hidden_states + token_type_embeds\n\n        hidden_states = self.drop(hidden_states)\n\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        # top attention custom query\n        last_layer_id = len(self.h) - 1\n        query_embeds = self.wqe(position_ids)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n            # Final LayerNorm before last query layer\n            if i == last_layer_id:\n                hidden_states = self.ln_f(hidden_states)\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        assert len(inputs) == 1\n                        input_0 = inputs[0]\n                        return module(**input_0, use_cache=use_cache, output_attentions=output_attentions)\n\n                    return custom_forward\n\n                outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    {\n                        \"hidden_states\": hidden_states,\n                        \"layer_past\": None,\n                        \"attention_mask\": attention_mask,\n                        \"head_mask\": head_mask[i],\n                        \"custom_query\": query_embeds if i == last_layer_id else None\n                    }\n                )\n            else:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    # custom query\n                    custom_query=query_embeds if i == last_layer_id else None,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = hidden_states.view(*output_shape)\n        # Add last hidden state\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass GPTPanguForCausalLM(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = GPTPanguModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):\n        token_type_ids = kwargs.get(\"token_type_ids\", None)\n        # only last token for inputs_ids if past is defined in kwargs\n        if past:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n            if token_type_ids is not None:\n                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)\n\n        attention_mask = kwargs.get(\"attention_mask\", None)\n        position_ids = kwargs.get(\"position_ids\", None)\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.int().cumsum(-1).long() - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            
if past:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n        else:\n            position_ids = None\n        return {\n            \"input_ids\": input_ids,\n            \"past_key_values\": past,\n            \"use_cache\": kwargs.get(\"use_cache\"),\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"token_type_ids\": token_type_ids,\n        }\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set\n            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to\n            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:\n        \"\"\"\n        This function is used to re-order the :obj:`past_key_values` cache if\n        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is\n        called. 
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.\n        \"\"\"\n        return tuple(\n            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)\n            for layer_past in past\n        )\n"
  },
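A minimal generation sketch for the causal LM above. The checkpoint path, prompt, and sampling settings are illustrative; passing `add_special_tokens=False` avoids the `<s>`/`<eot>` wrapping that `build_inputs_with_special_tokens` would otherwise apply, which is undesirable for open-ended continuation.

```python
# Minimal sketch (assumption: MODEL_PATH is a downloaded pangu checkpoint).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_PATH = "sunzeyeah/pangu-2.6B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
model.eval()

# Do not terminate the prompt with <eot> before generating.
inputs = tokenizer("西红柿炒鸡蛋的做法：", add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.8,
                             eos_token_id=tokenizer.eos_token_id,
                             pad_token_id=tokenizer.pad_token_id)
print(tokenizer.decode(outputs[0]))
```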
  {
    "path": "src/resources/models/pangu-13B/tokenization_gptpangu.py",
    "content": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\njieba.add_word('<s>')\njieba.add_word('</s>')\njieba.add_word('<eot>')\njieba.add_word('<unk>')\njieba.add_word('<sep>')\njieba.add_word('<pad>')\n\n\nclass GPTPanguTokenizer(PreTrainedTokenizer):\n    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py\n    vocab_files_names = {\n        \"model_file\": \"vocab.model\"\n    }\n\n    def __init__(\n            self,\n            model_file,\n            **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.sp = sentencepiece.SentencePieceProcessor()\n        self.sp.Load(model_file=model_file)\n        self.translator = str.maketrans(\" \\n\", \"\\u2582\\u2583\")\n        self.vocab_file = model_file\n        # special token ids\n        # self.eos_token_id = self.sp.piece_to_id(\"<eot>\")\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. A BERT sequence has the following format:\n\n        - single sequence: `[CLS] X [SEP]`\n        - pair of sequences: `[CLS] A [SEP] B [SEP]`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        if self.bos_token_id is not None:\n            if token_ids_1 is None:\n                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]\n            bos = [self.bos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return bos + token_ids_0 + sep + token_ids_1 + eos\n        else:\n            if token_ids_1 is None:\n                return token_ids_0 + [self.eos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return token_ids_0 + sep + token_ids_1 + eos\n\n    def tokenize(self, text, **kwargs):\n        \"\"\" Tokenize a string. 
\"\"\"\n        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]\n        return seg_list\n\n    def convert_tokens_to_ids(self, tokens):\n        if tokens is None:\n            return None\n\n        if isinstance(tokens, str):\n            return self._convert_token_to_id_with_added_voc(tokens)\n\n        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]\n\n        ids = []\n        i = 0\n        for j in special_tokens_index:\n            new_seg = \" \".join(tokens[i:j])\n            ids.extend(self.sp.encode(new_seg))\n            ids.append(self._convert_token_to_id(tokens[j]))\n            i = j + 1\n\n        new_seg = \" \".join(tokens[i:])\n        ids.extend(self.sp.encode(new_seg))\n\n        return ids\n\n        # new_seg = \" \".join(tokens)\n        # return self.sp.encode(new_seg)\n        # # return tokens\n\n    def _convert_token_to_id(self, token):\n        return self.sp.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        return self.sp.id_to_piece(index)\n\n    def convert_ids_to_tokens(self, ids):\n        return self.decode(ids)\n\n    def decode(self, ids, **kwargs):\n        if isinstance(ids, torch.Tensor) or isinstance(ids, np.ndarray):\n            ids = ids.tolist()\n\n        if kwargs.get('skip_special_tokens', None) is True:\n            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]\n        text = self.sp.decode(ids)\n        if isinstance(text, list):\n            text = text[0]\n        text = text.replace(' ', '').replace('\\u2582', ' ').replace('\\u2583', '\\n')#.replace('⁇', self.unk_token)\n        return text\n\n    @property\n    def vocab_size(self) -> int:\n        \"\"\"\n        `int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        return len(self.sp)\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the named of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"model_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)"
  },
  {
    "path": "src/resources/models/pangu-13B/tokenizer_config.json",
    "content": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n  \"add_prefix_space\": false,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"use_fast\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_gptpangu.GPTPanguTokenizer\",\n      null\n      ]\n  }\n}"
  },
  {
    "path": "src/resources/models/pangu-2.6B/config.json",
    "content": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-2.6B\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguForCausalLM\"\n  ],\n  \"attn_pdrop\": 0.1,\n  \"embd_pdrop\": 0.1,\n  \"hidden_size\": 2560,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": null,\n  \"layer_norm_epsilon\": 1e-05,\n  \"max_position_embeddings\": 1024,\n  \"model_type\": \"pangu\",\n  \"num_heads\": 32,\n  \"num_layers\": 32,\n  \"resid_pdrop\": 0.1,\n  \"scale_attn_weights\": true,\n  \"summary_activation\": null,\n  \"summary_first_dropout\": 0.1,\n  \"summary_proj_to_labels\": true,\n  \"summary_type\": \"cls_index\",\n  \"summary_use_proj\": true,\n  \"torch_dtype\": \"float32\",\n  \"transformers_version\": \"4.13.0\",\n  \"use_cache\": true,\n  \"vocab_size\": 40000,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_gptpangu.GPTPanguConfig\",\n    \"AutoTokenizer\": [\"tokenization_gptpangu.GPTPanguTokenizer\", null],\n    \"AutoModelForCausalLM\": \"modeling_gptpangu.GPTPanguForCausalLM\"\n  },\n  \"pad_token_id\": 6\n}\n"
  },
  {
    "path": "src/resources/models/pangu-2.6B/configuration_gptpangu.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type = \"gpt_pangu\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=40000,\n        max_position_embeddings=1024,\n        hidden_size=2560,\n        intermediate_size=None,\n        num_layers=32,\n        num_heads=32,\n        activation_function=\"gelu\",\n        resid_pdrop=0.1,\n        embd_pdrop=0.1,\n        attn_pdrop=0.1,\n        layer_norm_epsilon=1e-5,\n        scale_attn_weights=True,\n        initializer_range=0.02,\n        summary_type=\"cls_index\",\n        summary_use_proj=True,\n        summary_activation=None,\n        summary_proj_to_labels=True,\n        summary_first_dropout=0.1,\n        use_cache=True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.activation_function = activation_function\n        self.resid_pdrop = resid_pdrop\n        self.embd_pdrop = embd_pdrop\n        self.attn_pdrop = attn_pdrop\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.scale_attn_weights = scale_attn_weights\n        self.initializer_range = initializer_range\n        self.summary_type = summary_type\n        self.summary_use_proj = summary_use_proj\n        self.summary_activation = summary_activation\n        self.summary_first_dropout = summary_first_dropout\n        self.summary_proj_to_labels = summary_proj_to_labels\n        self.use_cache = use_cache\n\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "src/resources/models/pangu-2.6B/modeling_gptpangu.py",
    "content": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\n\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass GPTPanguAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(\n                1, 1, max_positions, max_positions\n            ),\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e4))\n\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.scale_attn_weights = config.scale_attn_weights\n\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n\n        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n        self.resid_dropout = nn.Dropout(config.resid_pdrop)\n\n\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        if self.scale_attn_weights:\n            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)\n\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()\n        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n\n        return attn_output, attn_weights\n\n    def _split_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Splits hidden_size dim into attn_head_size and num_heads\n        \"\"\"\n        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)\n        tensor = tensor.view(*new_shape)\n        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)\n\n    def _merge_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Merges attn_head_size dim and num_attn_heads dim into hidden_size\n        \"\"\"\n        tensor = tensor.permute(0, 
2, 1, 3).contiguous()\n        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)\n        return tensor.view(new_shape)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        query = self.q_proj(custom_query) if custom_query is not None else self.q_proj(hidden_states)\n        key = self.k_proj(hidden_states)\n        value = self.v_proj(hidden_states)\n\n        query = self._split_heads(query, self.num_heads, self.head_dim)\n        key = self._split_heads(key, self.num_heads, self.head_dim)\n        value = self._split_heads(value, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), dim=-2)\n\n        if use_cache is True:\n            present = (key, value)\n        else:\n            present = None\n\n        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)\n\n        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)\n        attn_output = self.c_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs  # a, present, (attentions)\n\n\nclass GPTPanguMLP(nn.Module):\n    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * hidden_size\n        super().__init__()\n        embed_dim = config.hidden_size\n        self.c_fc = nn.Linear(embed_dim, intermediate_size)\n        self.c_proj = nn.Linear(intermediate_size, embed_dim)\n        self.act = ACT2FN[config.activation_function]\n        self.dropout = nn.Dropout(config.resid_pdrop)\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        return hidden_states\n\n\nclass GPTPanguBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        hidden_size = config.hidden_size\n        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size\n\n        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.attn = GPTPanguAttention(config)\n        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.mlp = GPTPanguMLP(inner_dim, config)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        residual = hidden_states\n        hidden_states = self.ln_1(hidden_states)\n        attn_outputs = self.attn(\n            hidden_states,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            custom_query=custom_query,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)\n        outputs = attn_outputs[1:]\n        # residual connection\n        
hidden_states = attn_output + residual\n\n        residual = hidden_states\n        hidden_states = self.ln_2(hidden_states)\n        feed_forward_hidden_states = self.mlp(hidden_states)\n        # residual connection\n        hidden_states = residual + feed_forward_hidden_states\n\n        if use_cache:\n            outputs = (hidden_states,) + outputs\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        return outputs  # hidden_states, present, (attentions, cross_attentions)\n\n\nclass GPTPanguPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    # config_class = GPTPanguConfig\n    base_model_prefix = \"transformer\"\n    supports_gradient_checkpointing = True\n\n    def __init__(self, *inputs, **kwargs):\n        super().__init__(*inputs, **kwargs)\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if isinstance(module, (nn.Linear,)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if \"c_proj\" in name and \"weight\" in name:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_layers)))\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GPTPanguModel):\n            module.gradient_checkpointing = value\n\n\nclass GPTPanguModel(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embed_dim = config.hidden_size\n\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)\n        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n        self.wqe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n\n        self.drop = nn.Dropout(config.embd_pdrop)\n        self.h = nn.ModuleList([GPTPanguBlock(config) for _ in range(config.num_layers)])\n        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.wte\n\n    def set_input_embeddings(self, new_embeddings):\n        self.wte = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n            batch_size = input_ids.shape[0]\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size = inputs_embeds.shape[0]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if token_type_ids is not None:\n            token_type_ids = token_type_ids.view(-1, input_shape[-1])\n        if position_ids is not None:\n            position_ids = position_ids.view(-1, input_shape[-1])\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = 
tuple([None] * len(self.h))\n        else:\n            past_length = past_key_values[0][0].size(-2)\n        if position_ids is None:\n            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)\n            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n\n        # GPT2Attention mask.\n        if attention_mask is not None:\n            if batch_size <= 0:\n                raise ValueError(\"batch_size has to be defined and > 0\")\n            attention_mask = attention_mask.view(batch_size, -1)\n            # We create a 3D attention mask from a 2D tensor mask.\n            # Sizes are [batch_size, 1, 1, to_seq_length]\n            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]\n            # this attention mask is simpler than the triangular masking of causal attention\n            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.\n            attention_mask = attention_mask[:, None, None, :]\n\n            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n            # masked positions, this operation will create a tensor which is 0.0 for\n            # positions we want to attend and -10000.0 for masked positions.\n            # Since we are adding it to the raw scores before the softmax, this is\n            # effectively the same as removing these entirely.\n            if not torch.is_floating_point(attention_mask):  # covers bool and all integer masks, incl. the int64 masks tokenizers produce\n                attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n                attention_mask = (1.0 - attention_mask) * -10000.0\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicates we keep the head\n        # attention_probs has shape bsz x num_heads x N x N\n        # head_mask has shape n_layer x batch x num_heads x N x N\n        head_mask = self.get_head_mask(head_mask, self.config.num_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.wte(input_ids)\n        position_embeds = self.wpe(position_ids)\n        hidden_states = inputs_embeds + position_embeds\n\n        if token_type_ids is not None:\n            token_type_embeds = self.wte(token_type_ids)\n            hidden_states = hidden_states + token_type_embeds\n\n        hidden_states = self.drop(hidden_states)\n\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        # top attention custom query\n        last_layer_id = len(self.h) - 1\n        query_embeds = self.wqe(position_ids)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n            # Final LayerNorm before last query layer\n            if i == last_layer_id:\n                hidden_states = self.ln_f(hidden_states)\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        assert len(inputs) == 1\n                        input_0 = inputs[0]\n                        return module(**input_0, use_cache=use_cache, output_attentions=output_attentions)\n\n                    return custom_forward\n\n                outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    {\n                        \"hidden_states\": hidden_states,\n                        \"layer_past\": None,\n                        \"attention_mask\": attention_mask,\n                        \"head_mask\": head_mask[i],\n                        \"custom_query\": query_embeds if i == last_layer_id else None\n                    }\n                )\n            else:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    # custom query\n                    custom_query=query_embeds if i == last_layer_id else None,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = hidden_states.view(*output_shape)\n        # Add last hidden state\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass GPTPanguForCausalLM(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = GPTPanguModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):\n        token_type_ids = kwargs.get(\"token_type_ids\", None)\n        # only last token for inputs_ids if past is defined in kwargs\n        if past:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n            if token_type_ids is not None:\n                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)\n\n        attention_mask = kwargs.get(\"attention_mask\", None)\n        position_ids = kwargs.get(\"position_ids\", None)\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.int().cumsum(-1).long() - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            
if past:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n        else:\n            position_ids = None\n        return {\n            \"input_ids\": input_ids,\n            \"past_key_values\": past,\n            \"use_cache\": kwargs.get(\"use_cache\"),\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"token_type_ids\": token_type_ids,\n        }\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set\n            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to\n            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:\n        \"\"\"\n        This function is used to re-order the :obj:`past_key_values` cache if\n        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is\n        called. 
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.\n        \"\"\"\n        return tuple(\n            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)\n            for layer_past in past\n        )\n"
  },
  {
    "path": "src/resources/models/pangu-2.6B/tokenization_gptpangu.py",
    "content": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\njieba.add_word('<s>')\njieba.add_word('</s>')\njieba.add_word('<eot>')\njieba.add_word('<unk>')\njieba.add_word('<sep>')\njieba.add_word('<pad>')\n\n\nclass GPTPanguTokenizer(PreTrainedTokenizer):\n    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py\n    vocab_files_names = {\n        \"model_file\": \"vocab.model\"\n    }\n\n    def __init__(\n            self,\n            model_file,\n            **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.sp = sentencepiece.SentencePieceProcessor()\n        self.sp.Load(model_file=model_file)\n        self.translator = str.maketrans(\" \\n\", \"\\u2582\\u2583\")\n        self.vocab_file = model_file\n        # special token ids\n        # self.eos_token_id = self.sp.piece_to_id(\"<eot>\")\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. A BERT sequence has the following format:\n\n        - single sequence: `[CLS] X [SEP]`\n        - pair of sequences: `[CLS] A [SEP] B [SEP]`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        if self.bos_token_id is not None:\n            if token_ids_1 is None:\n                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]\n            bos = [self.bos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return bos + token_ids_0 + sep + token_ids_1 + eos\n        else:\n            if token_ids_1 is None:\n                return token_ids_0 + [self.eos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return token_ids_0 + sep + token_ids_1 + eos\n\n    def tokenize(self, text, **kwargs):\n        \"\"\" Tokenize a string. 
\"\"\"\n        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]\n        return seg_list\n\n    def convert_tokens_to_ids(self, tokens):\n        if tokens is None:\n            return None\n\n        if isinstance(tokens, str):\n            return self._convert_token_to_id_with_added_voc(tokens)\n\n        # Encode the spans between special tokens with sentencepiece and map each special\n        # token directly to its id, so that e.g. \"<sep>\" stays a single id.\n        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]\n\n        ids = []\n        i = 0\n        for j in special_tokens_index:\n            new_seg = \" \".join(tokens[i:j])\n            ids.extend(self.sp.encode(new_seg))\n            ids.append(self._convert_token_to_id(tokens[j]))\n            i = j + 1\n\n        new_seg = \" \".join(tokens[i:])\n        ids.extend(self.sp.encode(new_seg))\n\n        return ids\n\n    def _convert_token_to_id(self, token):\n        return self.sp.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        return self.sp.id_to_piece(index)\n\n    def convert_ids_to_tokens(self, ids):\n        return self.decode(ids)\n\n    def decode(self, ids, **kwargs):\n        if isinstance(ids, torch.Tensor) or isinstance(ids, np.ndarray):\n            ids = ids.tolist()\n\n        if kwargs.get('skip_special_tokens', None) is True:\n            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]\n        text = self.sp.decode(ids)\n        if isinstance(text, list):\n            text = text[0]\n        # drop sentencepiece's inserted spaces, then restore real spaces and newlines from their placeholders\n        text = text.replace(' ', '').replace('\\u2582', ' ').replace('\\u2583', '\\n')\n        return text\n\n    @property\n    def vocab_size(self) -> int:\n        \"\"\"\n        `int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        return len(self.sp)\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the names of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"model_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)"
  },
  {
    "path": "src/resources/models/pangu-2.6B/tokenizer_config.json",
    "content": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n  \"add_prefix_space\": false,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"use_fast\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_gptpangu.GPTPanguTokenizer\",\n      null\n      ]\n  }\n}"
  },
  {
    "path": "src/resources/models/pangu-350M/README-Reward.md",
    "content": "Link to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n# Model Description\n\nPangu-α is proposed by a joint technical team headed by PCNL. It was first released in [this repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha)  It is the first large-scale Chinese pre-trained language model with 200 billion parameters trained on 2048 Ascend processors using an automatic hybrid parallel training strategy. The whole training process is done on the “Peng Cheng Cloud Brain II” computing platform with the domestic deep learning framework called MindSpore. The PengCheng·PanGu-α pre-training model can support rich applications, has strong few-shot learning capabilities, and has outstanding performance in text generation tasks such as knowledge question and answer, knowledge retrieval, knowledge reasoning, and reading comprehension.\n\nThis repository contains PyTorch implementation of PanGu model with 350 million parameters pretrained weights (FP32 precision). It uses supervised finetuned [pangu-350M-sft](https://huggingface.co/sunzeyeah/pangu-350M-sft) and performs **reward training** on [Chinese Chatgpt Corpus](https://huggingface.co/datasets/sunzeyeah/chinese_chatgpt_corpus).\n"
  },
  {
    "path": "src/resources/models/pangu-350M/README-SFT.md",
    "content": "Link to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n# Model Description\n\nPangu-α is proposed by a joint technical team headed by PCNL. It was first released in [this repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha)  It is the first large-scale Chinese pre-trained language model with 200 billion parameters trained on 2048 Ascend processors using an automatic hybrid parallel training strategy. The whole training process is done on the “Peng Cheng Cloud Brain II” computing platform with the domestic deep learning framework called MindSpore. The PengCheng·PanGu-α pre-training model can support rich applications, has strong few-shot learning capabilities, and has outstanding performance in text generation tasks such as knowledge question and answer, knowledge retrieval, knowledge reasoning, and reading comprehension.\n\nThis repository contains PyTorch implementation of PanGu model with 350 million parameters pretrained weights (FP32 precision). It uses pretrained pangu-350M model and perform **supervised finetuning (SFT)** on [Chinese Chatgpt Corpus](https://huggingface.co/datasets/sunzeyeah/chinese_chatgpt_corpus).\n\n---\n\n# Usage (Text Generation)\n\nCurrently PanGu model is not supported by transformers,\nso `trust_remote_code=True` is required to load model implementation in this repo.\n\n```python\nfrom transformers import TextGenerationPipeline, AutoTokenizer, AutoModelForCausalLM\n\ntokenizer = AutoTokenizer.from_pretrained(\"sunzeyeah/pangu-350M-sft\", trust_remote_code=True)\nmodel = AutoModelForCausalLM.from_pretrained(\"sunzeyeah/pangu-350M-sft\", trust_remote_code=True)\n\nprompt = \"我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点<sep>回答：\"\ninputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\noutputs = model.generate(**inputs,\n                         max_new_tokens=100,\n                         pad_token_id=tokenizer.pad_token_id,\n                         do_sample=False,\n                         num_return_sequences=1,\n                         top_p=0.8,\n                         temperature=0.8)\nresults = tokenizer.batch_decode(outputs, skip_special_tokens=True)\nresults = [result.split(\"答:\", maxsplit=1)[1] for result in results]\nprint(results)\n```\n\nExpected output:\n```python\n[\"你爱他就多关心他,少在他面前表现出你的脆弱。这样他才更会爱你。\"]\n```\n"
  },
  {
    "path": "src/resources/models/pangu-350M/README.md",
    "content": "---\nlanguage:\n- zh\ntags:\n- pangu\n- chatgpt\n---\n\nLink to github: [here](https://github.com/sunzeyeah/RLHF)\n\n---\n\n\n# Model Description\n\nPangu-α is proposed by a joint technical team headed by PCNL. It was first released in [this repository](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha)  It is the first large-scale Chinese pre-trained language model with 200 billion parameters trained on 2048 Ascend processors using an automatic hybrid parallel training strategy. The whole training process is done on the “Peng Cheng Cloud Brain II” computing platform with the domestic deep learning framework called MindSpore. The PengCheng·PanGu-α pre-training model can support rich applications, has strong few-shot learning capabilities, and has outstanding performance in text generation tasks such as knowledge question and answer, knowledge retrieval, knowledge reasoning, and reading comprehension.\n\nThis repository contains PyTorch implementation of PanGu model with 350 million parameters pretrained weights (FP32 precision). \n\nIt is slightly different from the [original pangu implementation](https://huggingface.co/imone/pangu_2_6B) to support the ChatGPT training pipeline in this github repo: [sunzeyeah/RLHF](https://github.com/sunzeyeah/RLHF).\n\n---\n"
  },
  {
    "path": "src/resources/models/pangu-350M/config.json",
    "content": "{\n  \"model_name_or_path\": \"sunzeyeah/pangu-350M\",\n  \"activation_function\": \"gelu\",\n  \"architectures\": [\n    \"GPTPanguForCausalLM\"\n  ],\n  \"attn_pdrop\": 0.1,\n  \"embd_pdrop\": 0.1,\n  \"hidden_size\": 1024,\n  \"initializer_range\": 0.02,\n  \"intermediate_size\": null,\n  \"layer_norm_epsilon\": 1e-05,\n  \"max_position_embeddings\": 1024,\n  \"model_type\": \"pangu\",\n  \"num_heads\": 16,\n  \"num_layers\": 24,\n  \"resid_pdrop\": 0.1,\n  \"scale_attn_weights\": true,\n  \"summary_activation\": null,\n  \"summary_first_dropout\": 0.1,\n  \"summary_proj_to_labels\": true,\n  \"summary_type\": \"cls_index\",\n  \"summary_use_proj\": true,\n  \"torch_dtype\": \"float32\",\n  \"vocab_size\": 40000,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_gptpangu.GPTPanguConfig\",\n    \"AutoTokenizer\": [\"tokenization_gptpangu.GPTPanguTokenizer\", null],\n    \"AutoModelForCausalLM\": \"modeling_gptpangu.GPTPanguForCausalLM\"\n  },\n  \"pad_token_id\": 6\n}\n"
  },
  {
    "path": "src/resources/models/pangu-350M/configuration_gptpangu.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\n\n\nclass GPTPanguConfig(PretrainedConfig):\n    model_type = \"gpt_pangu\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=40000,\n        max_position_embeddings=1024,\n        hidden_size=1024,\n        intermediate_size=None,\n        num_layers=24,\n        num_heads=16,\n        activation_function=\"gelu\",\n        resid_pdrop=0.1,\n        embd_pdrop=0.1,\n        attn_pdrop=0.1,\n        layer_norm_epsilon=1e-5,\n        scale_attn_weights=True,\n        initializer_range=0.02,\n        summary_type=\"cls_index\",\n        summary_use_proj=True,\n        summary_activation=None,\n        summary_proj_to_labels=True,\n        summary_first_dropout=0.1,\n        use_cache=True,\n        # bos_token_id=9,\n        # eos_token_id=9,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.activation_function = activation_function\n        self.resid_pdrop = resid_pdrop\n        self.embd_pdrop = embd_pdrop\n        self.attn_pdrop = attn_pdrop\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.scale_attn_weights = scale_attn_weights\n        self.initializer_range = initializer_range\n        self.summary_type = summary_type\n        self.summary_use_proj = summary_use_proj\n        self.summary_activation = summary_activation\n        self.summary_first_dropout = summary_first_dropout\n        self.summary_proj_to_labels = summary_proj_to_labels\n        self.use_cache = use_cache\n\n        # self.bos_token_id = bos_token_id\n        # self.eos_token_id = eos_token_id\n\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "src/resources/models/pangu-350M/modeling_gptpangu.py",
    "content": "\"\"\"PyTorch PanguAlpha GPT2 Model\"\"\"\n# from .configuration_gptpangu import GPTPanguConfig\n\nfrom typing import Tuple\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\n\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass GPTPanguAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(\n                1, 1, max_positions, max_positions\n            ),\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e4))\n\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.scale_attn_weights = config.scale_attn_weights\n\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)\n\n        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n        self.resid_dropout = nn.Dropout(config.resid_pdrop)\n\n\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        if self.scale_attn_weights:\n            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)\n\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()\n        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))\n\n        if attention_mask is not None:\n            # Apply the attention mask\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n\n        return attn_output, attn_weights\n\n    def _split_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Splits hidden_size dim into attn_head_size and num_heads\n        \"\"\"\n        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)\n        tensor = tensor.view(*new_shape)\n        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)\n\n    def _merge_heads(self, tensor, num_heads, attn_head_size):\n        \"\"\"\n        Merges attn_head_size dim and num_attn_heads dim into hidden_size\n        \"\"\"\n        tensor = tensor.permute(0, 
2, 1, 3).contiguous()\n        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)\n        return tensor.view(new_shape)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        query = self.q_proj(custom_query) if custom_query is not None else self.q_proj(hidden_states)\n        key = self.k_proj(hidden_states)\n        value = self.v_proj(hidden_states)\n\n        query = self._split_heads(query, self.num_heads, self.head_dim)\n        key = self._split_heads(key, self.num_heads, self.head_dim)\n        value = self._split_heads(value, self.num_heads, self.head_dim)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key = torch.cat((past_key, key), dim=-2)\n            value = torch.cat((past_value, value), dim=-2)\n\n        if use_cache is True:\n            present = (key, value)\n        else:\n            present = None\n\n        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)\n\n        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)\n        attn_output = self.c_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs  # a, present, (attentions)\n\n\nclass GPTPanguMLP(nn.Module):\n    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * hidden_size\n        super().__init__()\n        embed_dim = config.hidden_size\n        self.c_fc = nn.Linear(embed_dim, intermediate_size)\n        self.c_proj = nn.Linear(intermediate_size, embed_dim)\n        self.act = ACT2FN[config.activation_function]\n        self.dropout = nn.Dropout(config.resid_pdrop)\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        return hidden_states\n\n\nclass GPTPanguBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        hidden_size = config.hidden_size\n        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size\n\n        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.attn = GPTPanguAttention(config)\n        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.mlp = GPTPanguMLP(inner_dim, config)\n\n    def forward(\n        self,\n        hidden_states,\n        layer_past=None,\n        attention_mask=None,\n        head_mask=None,\n        custom_query=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        residual = hidden_states\n        hidden_states = self.ln_1(hidden_states)\n        attn_outputs = self.attn(\n            hidden_states,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            custom_query=custom_query,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)\n        outputs = attn_outputs[1:]\n        # residual connection\n        
hidden_states = attn_output + residual\n\n        residual = hidden_states\n        hidden_states = self.ln_2(hidden_states)\n        feed_forward_hidden_states = self.mlp(hidden_states)\n        # residual connection\n        hidden_states = residual + feed_forward_hidden_states\n\n        if use_cache:\n            outputs = (hidden_states,) + outputs\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        return outputs  # hidden_states, present, (attentions, cross_attentions)\n\n\nclass GPTPanguPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    # config_class = GPTPanguConfig\n    base_model_prefix = \"transformer\"\n    supports_gradient_checkpointing = True\n\n    def __init__(self, *inputs, **kwargs):\n        super().__init__(*inputs, **kwargs)\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if isinstance(module, (nn.Linear,)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if \"c_proj\" in name and \"weight\" in name:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_layers)))\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, GPTPanguModel):\n            module.gradient_checkpointing = value\n\n\nclass GPTPanguModel(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embed_dim = config.hidden_size\n\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)\n        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n        self.wqe = nn.Embedding(config.max_position_embeddings, self.embed_dim)\n\n        self.drop = nn.Dropout(config.embd_pdrop)\n        self.h = nn.ModuleList([GPTPanguBlock(config) for _ in range(config.num_layers)])\n        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.wte\n\n    def set_input_embeddings(self, new_embeddings):\n        self.wte = new_embeddings\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n            batch_size = input_ids.shape[0]\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size = inputs_embeds.shape[0]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if token_type_ids is not None:\n            token_type_ids = token_type_ids.view(-1, input_shape[-1])\n        if position_ids is not None:\n            position_ids = position_ids.view(-1, input_shape[-1])\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = 
tuple([None] * len(self.h))\n        else:\n            past_length = past_key_values[0][0].size(-2)\n        if position_ids is None:\n            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)\n            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n\n        # GPT2Attention mask.\n        if attention_mask is not None:\n            if batch_size <= 0:\n                raise ValueError(\"batch_size has to be defined and > 0\")\n            attention_mask = attention_mask.view(batch_size, -1)\n            # We create a 3D attention mask from a 2D tensor mask.\n            # Sizes are [batch_size, 1, 1, to_seq_length]\n            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]\n            # this attention mask is simpler than the triangular masking of causal attention\n            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.\n            attention_mask = attention_mask[:, None, None, :]\n\n            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n            # masked positions, this operation will create a tensor which is 0.0 for\n            # positions we want to attend and -10000.0 for masked positions.\n            # Since we are adding it to the raw scores before the softmax, this is\n            # effectively the same as removing these entirely.\n            if not torch.is_floating_point(attention_mask):  # covers bool and all integer masks, incl. the int64 masks tokenizers produce\n                attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n                attention_mask = (1.0 - attention_mask) * -10000.0\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicates we keep the head\n        # attention_probs has shape bsz x num_heads x N x N\n        # head_mask has shape n_layer x batch x num_heads x N x N\n        head_mask = self.get_head_mask(head_mask, self.config.num_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.wte(input_ids)\n        position_embeds = self.wpe(position_ids)\n        hidden_states = inputs_embeds + position_embeds\n\n        if token_type_ids is not None:\n            token_type_embeds = self.wte(token_type_ids)\n            hidden_states = hidden_states + token_type_embeds\n\n        hidden_states = self.drop(hidden_states)\n\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        # top attention custom query\n        last_layer_id = len(self.h) - 1\n        query_embeds = self.wqe(position_ids)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n            # Final LayerNorm before last query layer\n            if i == last_layer_id:\n                hidden_states = self.ln_f(hidden_states)\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n\n                if use_cache:\n                    logger.warning(\n                        \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        assert len(inputs) == 1\n                        input_0 = inputs[0]\n                        return module(**input_0, use_cache=use_cache, output_attentions=output_attentions)\n\n                    return custom_forward\n\n                outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    {\n                        \"hidden_states\": hidden_states,\n                        \"layer_past\": None,\n                        \"attention_mask\": attention_mask,\n                        \"head_mask\": head_mask[i],\n                        \"custom_query\": query_embeds if i == last_layer_id else None\n                    }\n                )\n            else:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    # custom query\n                    custom_query=query_embeds if i == last_layer_id else None,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = hidden_states.view(*output_shape)\n        # Add last hidden state\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass GPTPanguForCausalLM(GPTPanguPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = GPTPanguModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):\n        token_type_ids = kwargs.get(\"token_type_ids\", None)\n        # only last token for inputs_ids if past is defined in kwargs\n        if past:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n            if token_type_ids is not None:\n                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)\n\n        attention_mask = kwargs.get(\"attention_mask\", None)\n        position_ids = kwargs.get(\"position_ids\", None)\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.int().cumsum(-1).long() - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            
if past:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n        else:\n            position_ids = None\n        return {\n            \"input_ids\": input_ids,\n            \"past_key_values\": past,\n            \"use_cache\": kwargs.get(\"use_cache\"),\n            \"position_ids\": position_ids,\n            \"attention_mask\": attention_mask,\n            \"token_type_ids\": token_type_ids,\n        }\n\n    def forward(\n        self,\n        input_ids=None,\n        past_key_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set\n            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to\n            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)\n            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:\n        \"\"\"\n        This function is used to re-order the :obj:`past_key_values` cache if\n        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is\n        called. 
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.\n        \"\"\"\n        return tuple(\n            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)\n            for layer_past in past\n        )\n"
  },
  {
    "path": "src/resources/models/pangu-350M/tokenization_gptpangu.py",
    "content": "import os\nimport torch\nimport sentencepiece\nimport jieba\nimport numpy as np\n\nfrom transformers.tokenization_utils import PreTrainedTokenizer\n\njieba.add_word('<s>')\njieba.add_word('</s>')\njieba.add_word('<eot>')\njieba.add_word('<unk>')\njieba.add_word('<sep>')\njieba.add_word('<pad>')\n\n\nclass GPTPanguTokenizer(PreTrainedTokenizer):\n    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py\n    vocab_files_names = {\n        \"model_file\": \"vocab.model\"\n    }\n\n    def __init__(\n            self,\n            model_file,\n            **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.sp = sentencepiece.SentencePieceProcessor()\n        self.sp.Load(model_file=model_file)\n        self.translator = str.maketrans(\" \\n\", \"\\u2582\\u2583\")\n        self.vocab_file = model_file\n        # special token ids\n        # self.eos_token_id = self.sp.piece_to_id(\"<eot>\")\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        \"\"\"\n        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n        adding special tokens. A BERT sequence has the following format:\n\n        - single sequence: `[CLS] X [SEP]`\n        - pair of sequences: `[CLS] A [SEP] B [SEP]`\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs to which the special tokens will be added.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.\n        \"\"\"\n        if self.bos_token_id is not None:\n            if token_ids_1 is None:\n                return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]\n            bos = [self.bos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return bos + token_ids_0 + sep + token_ids_1 + eos\n        else:\n            if token_ids_1 is None:\n                return token_ids_0 + [self.eos_token_id]\n            sep = [self.sep_token_id]\n            eos = [self.eos_token_id]\n            return token_ids_0 + sep + token_ids_1 + eos\n\n    def tokenize(self, text, **kwargs):\n        \"\"\" Tokenize a string. 
\"\"\"\n        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]\n        return seg_list\n\n    def convert_tokens_to_ids(self, tokens):\n        if tokens is None:\n            return None\n\n        if isinstance(tokens, str):\n            return self._convert_token_to_id_with_added_voc(tokens)\n\n        special_tokens_index = [i for i, token in enumerate(tokens) if token in self.all_special_tokens]\n\n        ids = []\n        i = 0\n        for j in special_tokens_index:\n            new_seg = \" \".join(tokens[i:j])\n            ids.extend(self.sp.encode(new_seg))\n            ids.append(self._convert_token_to_id(tokens[j]))\n            i = j + 1\n\n        new_seg = \" \".join(tokens[i:])\n        ids.extend(self.sp.encode(new_seg))\n\n        return ids\n\n        # new_seg = \" \".join(tokens)\n        # return self.sp.encode(new_seg)\n        # # return tokens\n\n    def _convert_token_to_id(self, token):\n        return self.sp.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        return self.sp.id_to_piece(index)\n\n    def convert_ids_to_tokens(self, ids):\n        return self.decode(ids)\n\n    def decode(self, ids, **kwargs):\n        if isinstance(ids, torch.Tensor) or isinstance(ids, np.ndarray):\n            ids = ids.tolist()\n\n        if kwargs.get('skip_special_tokens', None) is True:\n            ids = [token_id for token_id in ids if token_id not in self.all_special_ids]\n        text = self.sp.decode(ids)\n        if isinstance(text, list):\n            text = text[0]\n        text = text.replace(' ', '').replace('\\u2582', ' ').replace('\\u2583', '\\n')#.replace('⁇', self.unk_token)\n        return text\n\n    @property\n    def vocab_size(self) -> int:\n        \"\"\"\n        `int`: Size of the base vocabulary (without the added tokens).\n        \"\"\"\n        return len(self.sp)\n\n    def save_vocabulary(self, save_directory, filename_prefix=None):\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n            filename_prefix (`str`, *optional*):\n                An optional prefix to add to the named of the saved files.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if os.path.isdir(save_directory):\n            vocab_file = os.path.join(\n                save_directory, self.vocab_files_names[\"model_file\"]\n            )\n        else:\n            vocab_file = save_directory\n\n        with open(self.vocab_file, 'rb') as fin:\n            proto_str = fin.read()\n\n        with open(vocab_file, \"wb\") as writer:\n            writer.write(proto_str)\n\n        return (vocab_file,)"
  },
  {
    "path": "src/resources/models/pangu-350M/tokenizer_config.json",
    "content": "{\n  \"eos_token\": \"<eot>\",\n  \"pad_token\": \"<pad>\",\n  \"unk_token\": \"<unk>\",\n  \"sep_token\": \"<sep>\",\n  \"bos_token\": \"<s>\",\n  \"add_prefix_space\": false,\n  \"tokenizer_class\": \"GPTPanguTokenizer\",\n  \"use_fast\": false,\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_gptpangu.GPTPanguTokenizer\",\n      null\n      ]\n  }\n}"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/config.json",
    "content": "{\n  \"activation\": \"swiglu\",\n  \"apply_residual_connection_post_layernorm\": false,\n  \"architectures\": [\n    \"QWenLMHeadModel\"\n  ],\n  \"auto_map\": {\n    \"AutoConfig\": \"configuration_qwen.QWenConfig\",\n    \"AutoModelForCausalLM\": \"modeling_qwen.QWenLMHeadModel\"\n  },\n  \"attn_pdrop\": 0.0,\n  \"bf16\": false,\n  \"bias_dropout_fusion\": true,\n  \"bos_token_id\": 151643,\n  \"embd_pdrop\": 0.1,\n  \"eos_token_id\": 151643,\n  \"ffn_hidden_size\": 22016,\n  \"fp16\": false,\n  \"initializer_range\": 0.02,\n  \"kv_channels\": 128,\n  \"layer_norm_epsilon\": 1e-05,\n  \"model_type\": \"qwen\",\n  \"n_embd\": 4096,\n  \"n_head\": 32,\n  \"n_layer\": 32,\n  \"n_positions\": 6144,\n  \"no_bias\": true,\n  \"onnx_safe\": null,\n  \"padded_vocab_size\": 151936,\n  \"params_dtype\": \"torch.bfloat16\",\n  \"pos_emb\": \"rotary\",\n  \"resid_pdrop\": 0.1,\n  \"rotary_emb_base\": 10000,\n  \"rotary_pct\": 1.0,\n  \"scale_attn_weights\": true,\n  \"seq_length\": 2048,\n  \"tie_word_embeddings\": false,\n  \"tokenizer_type\": \"QWenTokenizer\",\n  \"transformers_version\": \"4.31.0\",\n  \"use_cache\": true,\n  \"use_flash_attn\": true,\n  \"vocab_size\": 151936,\n  \"use_dynamic_ntk\": true,\n  \"use_logn_attn\": true\n}\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/configuration_qwen.py",
    "content": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom transformers import PretrainedConfig\n\n\nclass QWenConfig(PretrainedConfig):\n    model_type = \"qwen\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n    attribute_map = {\n        \"hidden_size\": \"n_embd\",\n        \"num_attention_heads\": \"n_head\",\n        \"max_position_embeddings\": \"n_positions\",\n        \"num_hidden_layers\": \"n_layer\",\n    }\n\n    def __init__(\n        self,\n        vocab_size=151851,\n        n_embd=4096,\n        n_layer=32,\n        n_head=32,\n        n_inner=None,\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        layer_norm_epsilon=1e-5,\n        initializer_range=0.02,\n        scale_attn_weights=True,\n        use_cache=True,\n        eos_token_id=151643,\n        apply_residual_connection_post_layernorm=False,\n        bf16=True,\n        kv_channels=128,\n        rotary_pct=1.0,\n        rotary_emb_base=10000,\n        use_dynamic_ntk=False,\n        use_logn_attn=False,\n        use_flash_attn=True,\n        ffn_hidden_size=22016,\n        no_bias=True,\n        tie_word_embeddings=False,\n        **kwargs,\n    ):\n        self.eos_token_id = eos_token_id\n        super().__init__(\n            eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs\n        )\n\n        self.vocab_size = vocab_size\n        self.n_embd = n_embd\n        self.n_layer = n_layer\n        self.n_head = n_head\n        self.n_inner = n_inner\n        self.embd_pdrop = embd_pdrop\n        self.attn_pdrop = attn_pdrop\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.initializer_range = initializer_range\n        self.scale_attn_weights = scale_attn_weights\n        self.use_cache = use_cache\n        self.apply_residual_connection_post_layernorm = (\n            apply_residual_connection_post_layernorm\n        )\n        self.bf16 = bf16\n        self.kv_channels = kv_channels\n        self.rotary_pct = rotary_pct\n        self.rotary_emb_base = rotary_emb_base\n        self.use_dynamic_ntk = use_dynamic_ntk\n        self.use_logn_attn = use_logn_attn\n        self.use_flash_attn = use_flash_attn\n        self.ffn_hidden_size = ffn_hidden_size\n        self.no_bias = no_bias\n        self.tie_word_embeddings = tie_word_embeddings\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/generation_config.json",
    "content": "{\n  \"chat_format\": \"chatml\",\n  \"decay_bound\": 0.0,\n  \"decay_factor\": 1.0,\n  \"eos_token_id\": 151643,\n  \"factual_nucleus_sampling\": false,\n  \"max_context_size\": 1024,\n  \"max_generate_size\": 512,\n  \"max_new_tokens\": 512,\n  \"pad_token_id\": 151643,\n  \"stop_words_ids\": [[151643]],\n  \"do_sample\": true,\n  \"top_k\": 0,\n  \"top_p\": 0.8,\n  \"transformers_version\": \"4.31.0\"\n}\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/modeling_qwen.py",
    "content": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport importlib\nimport math\nfrom typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch.cuda.amp import autocast\n\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList\nfrom transformers.generation.logits_process import LogitsProcessorList\nif TYPE_CHECKING:\n    from transformers.generation.streamers import BaseStreamer\nfrom transformers.generation.utils import GenerateOutput\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\n\ntry:\n    from einops import rearrange\nexcept ImportError:\n    rearrange = None\nfrom torch import nn\n\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb_func\n    from einops import rearrange\n\n    use_flash_rotary = True\nexcept ImportError:\n    use_flash_rotary = False\n    print(\"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance \"\n          \"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\")\n\ntry:\n    from flash_attn.ops.rms_norm import rms_norm\nexcept ImportError:\n    rms_norm = None\n    print(\"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance \"\n          \"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\")\n\nfrom .configuration_qwen import QWenConfig\nfrom .qwen_generation_utils import (\n    HistoryType,\n    make_context,\n    decode_tokens,\n    get_stop_words_ids,\n    StopWordsLogitsProcessor,\n)\n\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"qwen\"\n_CONFIG_FOR_DOC = \"QWenConfig\"\n\nQWen_PRETRAINED_MODEL_ARCHIVE_LIST = [\"qwen-7b\"]\n\ntry:\n    from flash_attn.flash_attn_interface import flash_attn_unpadded_func\nexcept ImportError:\n    flash_attn_unpadded_func = None\n    print(\"Warning: import flash_attn fail, please install FlashAttention \"\n          \"https://github.com/Dao-AILab/flash-attention\")\n\n\nclass FlashSelfAttention(torch.nn.Module):\n    def __init__(\n        self,\n        causal=False,\n        softmax_scale=None,\n        attention_dropout=0.0,\n    ):\n        super().__init__()\n        assert flash_attn_unpadded_func is not None, (\n            \"Please install FlashAttention first, \" \"e.g., with pip install flash-attn\"\n        )\n        assert (\n            rearrange is not None\n        ), \"Please install einops first, e.g., with pip install einops\"\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.dropout_p = attention_dropout\n\n    def forward(self, q, k, v):\n        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))\n        assert all((i.is_cuda for i in (q, k, v)))\n        batch_size, seqlen_q = q.shape[0], q.shape[1]\n        seqlen_k = k.shape[1]\n        q, k, v = [rearrange(x, \"b s ... 
-> (b s) ...\") for x in [q, k, v]]\n        cu_seqlens_q = torch.arange(\n            0,\n            (batch_size + 1) * seqlen_q,\n            step=seqlen_q,\n            dtype=torch.int32,\n            device=q.device,\n        )\n\n        if self.training:\n            assert seqlen_k == seqlen_q\n\n            is_causal = self.causal\n            cu_seqlens_k = cu_seqlens_q\n        else:\n            is_causal = seqlen_q == seqlen_k\n            cu_seqlens_k = torch.arange(\n                0,\n                (batch_size + 1) * seqlen_k,\n                step=seqlen_k,\n                dtype=torch.int32,\n                device=q.device,\n            )\n            self.dropout_p = 0\n        output = flash_attn_unpadded_func(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqlen_q,\n            seqlen_k,\n            self.dropout_p,\n            softmax_scale=self.softmax_scale,\n            causal=is_causal,\n        )\n\n        output = rearrange(output, \"(b s) ... -> b s ...\", b=batch_size)\n        return output\n\n\nclass QWenAttention(nn.Module):\n    def __init__(self, config, layer_number=None):\n        super().__init__()\n\n        max_positions = config.max_position_embeddings\n        self.register_buffer(\n            \"bias\",\n            torch.tril(\n                torch.ones((max_positions, max_positions), dtype=torch.bool)\n            ).view(1, 1, max_positions, max_positions),\n            persistent=False,\n        )\n        self.register_buffer(\"masked_bias\", torch.tensor(-1e4), persistent=False)\n        self.layer_number = max(1, layer_number)\n        self.params_dtype = config.params_dtype\n        self.seq_length = config.seq_length\n\n        self.hidden_size = config.hidden_size\n        self.split_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n\n        self.use_flash_attn = config.use_flash_attn\n        self.scale_attn_weights = True\n\n        self.layer_idx = None\n\n        self.projection_size = config.kv_channels * config.num_attention_heads\n\n        assert self.projection_size % config.num_attention_heads == 0\n        self.hidden_size_per_attention_head = (\n            self.projection_size // config.num_attention_heads\n        )\n\n        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)\n\n        self.c_proj = nn.Linear(\n            config.hidden_size, self.projection_size, bias=not config.no_bias\n        )\n\n        self.is_fp32 = not(config.bf16 or config.fp16)\n        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:\n            self.core_attention_flash = FlashSelfAttention(\n                causal=True, attention_dropout=config.attn_pdrop\n            )\n\n        self.bf16 = config.bf16\n\n        if config.rotary_pct == 1.0:\n            self.rotary_ndims = None\n        else:\n            assert config.rotary_pct < 1\n            self.rotary_ndims = int(\n                self.hidden_size_per_attention_head * config.rotary_pct\n            )\n        dim = (\n            self.rotary_ndims\n            if self.rotary_ndims is not None\n            else self.hidden_size_per_attention_head\n        )\n        self.rotary_emb = RotaryEmbedding(\n            dim, base=config.rotary_emb_base\n        )\n\n        self.use_dynamic_ntk = config.use_dynamic_ntk\n        self.use_logn_attn = 
config.use_logn_attn\n\n        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]\n        self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]\n        self._ntk_cached = 1.0\n\n        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        attn_weights = torch.matmul(query, key.transpose(-1, -2))\n\n        if self.scale_attn_weights:\n            attn_weights = attn_weights / torch.full(\n                [],\n                value.size(-1) ** 0.5,\n                dtype=attn_weights.dtype,\n                device=attn_weights.device,\n            )\n\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[\n            :, :, key_length - query_length : key_length, :key_length\n        ]\n        mask_value = torch.finfo(attn_weights.dtype).min\n        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(\n            attn_weights.device\n        )\n        attn_weights = torch.where(\n            causal_mask, attn_weights.to(attn_weights.dtype), mask_value\n        )\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n        attn_output = attn_output.transpose(1, 2)\n\n        return attn_output, attn_weights\n\n    def _upcast_and_reordered_attn(\n        self, query, key, value, attention_mask=None, head_mask=None\n    ):\n        bsz, num_heads, q_seq_len, dk = query.size()\n        _, _, k_seq_len, _ = key.size()\n\n        attn_weights = torch.empty(\n            bsz * num_heads,\n            q_seq_len,\n            k_seq_len,\n            dtype=torch.float32,\n            device=query.device,\n        )\n\n        scale_factor = 1.0\n        if self.scale_attn_weights:\n            scale_factor /= float(value.size(-1)) ** 0.5\n\n        with autocast(enabled=False):\n            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(\n                -1, dk, k_seq_len\n            )\n            attn_weights = torch.baddbmm(\n                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor\n            )\n            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)\n\n        query_length, key_length = query.size(-2), key.size(-2)\n        causal_mask = self.bias[\n            :, :, key_length - query_length : key_length, :key_length\n        ]\n        mask_value = torch.finfo(attn_weights.dtype).min\n        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(\n            attn_weights.device\n        )\n        attn_weights = torch.where(causal_mask, attn_weights, mask_value)\n\n        if attention_mask is not None:\n            attn_weights = attn_weights + attention_mask\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if attn_weights.dtype != torch.float32:\n            raise RuntimeError(\n                \"Error with upcasting, attn_weights does not have dtype torch.float32\"\n            )\n        attn_weights = attn_weights.type(value.dtype)\n        attn_weights = self.attn_dropout(attn_weights)\n\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n       
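 # weighted sum of the value vectors under the attention probabilities\n       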
 attn_output = torch.matmul(attn_weights, value)\n\n        return attn_output, attn_weights\n\n    def _split_heads(self, tensor, num_heads, attn_head_size):\n        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)\n        tensor = tensor.view(new_shape)\n        return tensor\n\n    def _merge_heads(self, tensor, num_heads, attn_head_size):\n        tensor = tensor.contiguous()\n        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)\n        return tensor.view(new_shape)\n\n    def forward(\n        self,\n        hidden_states: Optional[Tuple[torch.FloatTensor]],\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n    ):\n\n        mixed_x_layer = self.c_attn(hidden_states)\n        query, key, value = mixed_x_layer.split(self.split_size, dim=2)\n\n        query = self._split_heads(query, self.num_heads, self.head_dim)\n        key = self._split_heads(key, self.num_heads, self.head_dim)\n        value = self._split_heads(value, self.num_heads, self.head_dim)\n\n        kv_seq_len = hidden_states.size()[1]\n        if layer_past:\n            # layer past[0] shape: bs * seq_len * head_num * dim\n            kv_seq_len += layer_past[0].shape[1]\n        if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1] and not self.training:\n            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1\n            ntk_alpha = 2 ** math.ceil(context_value) - 1\n            ntk_alpha = max(ntk_alpha, 1)\n            self._ntk_cached = ntk_alpha\n        else:\n            ntk_alpha = self._ntk_cached\n        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(hidden_states.device)\n\n        if rotary_pos_emb is not None:\n            if isinstance(rotary_pos_emb, tuple):\n                rotary_pos_emb = rotary_pos_emb\n            else:\n                rotary_pos_emb = (rotary_pos_emb,) * 2\n\n        if rotary_pos_emb is not None:\n            q_pos_emb, k_pos_emb = rotary_pos_emb\n            # Slice the pos emb for current inference\n            cur_len = query.shape[1]\n            q_pos_emb = q_pos_emb[:, -cur_len:, :, :]\n            k_pos_emb = k_pos_emb[:, -cur_len:, :, :]\n            query = apply_rotary_pos_emb(query, q_pos_emb)\n            key = apply_rotary_pos_emb(key, k_pos_emb)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past[0], layer_past[1]\n            key = torch.cat((past_key, key), dim=1)\n            value = torch.cat((past_value, value), dim=1)\n\n        if use_cache:\n            present = (key, value)\n        else:\n            present = None\n\n        if self.use_logn_attn and not self.training:\n            if self.logn_tensor.device != query.device:\n                self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)\n            seq_start = key.size(1) - query.size(1)\n            seq_end = key.size(1)\n            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]\n            query = query * logn_tensor.expand_as(query)\n\n        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:\n            q, k, v = query, key, value\n            context_layer = 
self.core_attention_flash(q, k, v)\n\n            context_layer = rearrange(\n                context_layer, \"b s h d -> b s (h d)\"\n            ).contiguous()\n        else:\n            query = query.permute(0, 2, 1, 3)\n            key = key.permute(0, 2, 1, 3)\n            value = value.permute(0, 2, 1, 3)\n            attn_output, attn_weight = self._attn(\n                query, key, value, attention_mask, head_mask\n            )\n            context_layer = self._merge_heads(\n                attn_output, self.num_heads, self.head_dim\n            )\n\n        attn_output = self.c_proj(context_layer)\n        outputs = (attn_output, present)\n        if output_attentions:\n            if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:\n                raise ValueError(\"Cannot output attentions while using flash-attn\")\n            else:\n                outputs += (attn_weight,)\n\n        return outputs\n\n\nclass QWenMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.w1 = nn.Linear(\n            config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias\n        )\n        self.w2 = nn.Linear(\n            config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias\n        )\n        ff_dim_in = config.ffn_hidden_size // 2\n        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)\n\n    def forward(self, hidden_states):\n        a1 = self.w1(hidden_states)\n        a2 = self.w2(hidden_states)\n        intermediate_parallel = a1 * F.silu(a2)\n        output = self.c_proj(intermediate_parallel)\n        return output\n\n\nclass QWenBlock(nn.Module):\n    def __init__(self, config, layer_idx=None, num_expert=1):\n        super().__init__()\n        self.num_expert = num_expert\n        self.layer_number = layer_idx\n        self.apply_residual_connection_post_layernorm = (\n            config.apply_residual_connection_post_layernorm\n        )\n        hidden_size = config.hidden_size\n        self.apply_residual_connection_post_layernorm = (\n            config.apply_residual_connection_post_layernorm\n        )\n        self.bf16 = config.bf16\n\n        self.ln_1 = RMSNorm(\n            hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n        self.attn = QWenAttention(config, layer_number=layer_idx)\n        self.ln_2 = RMSNorm(\n            hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.mlp = QWenMLP(config)\n\n    def forward(\n        self,\n        hidden_states: Optional[Tuple[torch.FloatTensor]],\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = False,\n        output_attentions: Optional[bool] = False,\n    ):\n        layernorm_output = self.ln_1(hidden_states)\n\n        attn_outputs = self.attn(\n            layernorm_output,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        attn_output = attn_outputs[0]\n\n        outputs = attn_outputs[1:]\n\n        if self.apply_residual_connection_post_layernorm:\n            residual = 
layernorm_output\n        else:\n            residual = hidden_states\n        layernorm_input = attn_output + residual\n\n        layernorm_output = self.ln_2(layernorm_input)\n\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        mlp_output = self.mlp(layernorm_output)\n        hidden_states = residual + mlp_output\n\n        if use_cache:\n            outputs = (hidden_states,) + outputs\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        return outputs\n\n\nclass QWenPreTrainedModel(PreTrainedModel):\n    config_class = QWenConfig\n    base_model_prefix = \"transformer\"\n    is_parallelizable = False\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"QWenBlock\"]\n\n    def __init__(self, *inputs, **kwargs):\n        super().__init__(*inputs, **kwargs)\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, RMSNorm):\n            module.weight.data.fill_(1.0)\n\n        for name, p in module.named_parameters():\n            if name == \"c_proj.weight\":\n                p.data.normal_(\n                    mean=0.0,\n                    std=(\n                        self.config.initializer_range\n                        / math.sqrt(2 * self.config.n_layer)\n                    ),\n                )\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, QWenModel):\n            module.gradient_checkpointing = value\n\n\nclass QWenModel(QWenPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [\"attn.masked_bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.vocab_size = config.padded_vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.embed_dim = config.hidden_size\n\n        max_sequence_length = config.max_position_embeddings\n        self.position_embedding_type = config.pos_emb\n        self.gradient_checkpointing = False\n\n        if self.position_embedding_type == \"learned\":\n            # the embedding weights are initialized in self.post_init() via _init_weights\n            self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)\n            self._position_embeddings_key = \"position_embeddings\"\n        else:\n            self.wpe = None\n            self._position_embeddings_key = \"\"\n\n        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)\n\n        self.drop = nn.Dropout(config.embd_pdrop)\n        self.h = nn.ModuleList(\n            [\n                QWenBlock(\n                    config,\n                    layer_idx=i,\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n        self.ln_f = RMSNorm(\n            self.embed_dim,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.wte\n\n  
  def set_input_embeddings(self, new_embeddings):\n        self.wte = new_embeddings\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n            batch_size = input_ids.shape[0]\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size = inputs_embeds.shape[0]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if token_type_ids is not None:\n            token_type_ids = token_type_ids.view(-1, input_shape[-1])\n        if position_ids is not None:\n            position_ids = position_ids.view(-1, input_shape[-1])\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = tuple([None] * len(self.h))\n        else:\n            past_length = past_key_values[0][0].size(-2)\n\n        if position_ids is None:\n            position_ids = torch.arange(\n                past_length,\n                input_shape[-1] + past_length,\n                dtype=torch.long,\n                device=device,\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])\n\n        if attention_mask is not None:\n            if batch_size <= 0:\n                raise ValueError(\"batch_size has to be defined and > 0\")\n            attention_mask = attention_mask.view(batch_size, -1)\n            attention_mask = attention_mask[:, None, None, :]\n            attention_mask = attention_mask.to(dtype=self.dtype)\n            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min\n\n        encoder_attention_mask = None\n        head_mask = self.get_head_mask(head_mask, self.config.n_layer)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.wte(input_ids)\n        hidden_states = inputs_embeds\n  
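      # wpe is only set for learned position embeddings; rotary position encoding is applied inside QWenAttention\n  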
      if self.wpe is not None:\n            position_embeds = self.wpe(position_ids)\n            hidden_states = hidden_states + position_embeds\n\n        hidden_states = self.drop(hidden_states)\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, use_cache, output_attentions)\n\n                    return custom_forward\n\n                outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    hidden_states,\n                    None,\n                    attention_mask,\n                    head_mask[i],\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                )\n            else:\n                outputs = block(\n                    hidden_states,\n                    layer_past=layer_past,\n                    attention_mask=attention_mask,\n                    head_mask=head_mask[i],\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    use_cache=use_cache,\n                    output_attentions=output_attentions,\n                )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                # QWenBlock always places the present key/value pair at index 1\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                # the attention weights follow the present entry when use_cache is on\n                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = self.ln_f(hidden_states)\n        hidden_states = hidden_states.view(output_shape)\n\n        if not return_dict:\n            return tuple(\n                v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None\n            )\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass QWenLMHeadModel(QWenPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"h\\.\\d+\\.attn\\.rotary_emb\\.inv_freq\"]\n    _keys_to_ignore_on_load_unexpected = [r\"h\\.\\d+\\.attn\\.masked_bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = QWenModel(config)\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n        assert not (config.bf16 and config.fp16), \"In config, bf16 and fp16 cannot both be true\"\n        if config.bf16:\n            self.transformer.bfloat16()\n            self.lm_head.bfloat16()\n        if config.fp16:\n            
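# fp16: cast the transformer and lm_head weights to half precision\n            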
self.transformer.half()\n            self.lm_head.half()\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs\n    ):\n        token_type_ids = kwargs.get(\"token_type_ids\", None)\n        if past_key_values:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n            if token_type_ids is not None:\n                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)\n\n        attention_mask = kwargs.get(\"attention_mask\", None)\n        position_ids = kwargs.get(\"position_ids\", None)\n\n        if attention_mask is not None and position_ids is None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n        else:\n            position_ids = None\n\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"position_ids\": position_ids,\n                \"attention_mask\": attention_mask,\n                \"token_type_ids\": token_type_ids,\n            }\n        )\n        return model_inputs\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        lm_logits = self.lm_head(hidden_states)\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(lm_logits.device)\n            shift_logits = 
lm_logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(\n                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n            )\n\n        if not return_dict:\n            output = (lm_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=lm_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n    @staticmethod\n    def _reorder_cache(\n        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor\n    ) -> Tuple[Tuple[torch.Tensor]]:\n\n        return tuple(\n            tuple(\n                past_state.index_select(0, beam_idx.to(past_state.device))\n                for past_state in layer_past\n            )\n            for layer_past in past_key_values\n        )\n\n    def chat(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        query: str,\n        history: Optional[HistoryType],\n        system: str = \"You are a helpful assistant.\",\n        append_history: bool = True,\n    ) -> Tuple[str, HistoryType]:\n\n        if history is None:\n            history = []\n\n        raw_text, context_tokens = make_context(\n            tokenizer,\n            query,\n            history=history,\n            system=system,\n            max_window_size=6144,\n            chat_format=self.generation_config.chat_format,\n        )\n\n        stop_words_ids = get_stop_words_ids(\n            self.generation_config.chat_format, tokenizer\n        )\n        input_ids = torch.tensor([context_tokens]).to(self.device)\n\n        outputs = self.generate(\n            input_ids,\n            stop_words_ids=stop_words_ids,\n            return_dict_in_generate=False,\n        )\n\n        response = decode_tokens(\n            outputs[0],\n            tokenizer,\n            raw_text_len=len(raw_text),\n            context_length=len(context_tokens),\n            chat_format=self.generation_config.chat_format,\n            verbose=False,\n        )\n\n        if append_history:\n            history.append((query, response))\n\n        return response, history\n\n    def generate(\n        self,\n        inputs: Optional[torch.Tensor] = None,\n        generation_config: Optional[GenerationConfig] = None,\n        logits_processor: Optional[LogitsProcessorList] = None,\n        stopping_criteria: Optional[StoppingCriteriaList] = None,\n        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n        synced_gpus: Optional[bool] = None,\n        streamer: Optional[\"BaseStreamer\"] = None,\n        **kwargs,\n    ) -> Union[GenerateOutput, torch.LongTensor]:\n        # Process stop_words_ids.\n        stop_words_ids = kwargs.pop('stop_words_ids', None)\n        if stop_words_ids is None and generation_config is not None:\n            stop_words_ids = getattr(generation_config, 'stop_words_ids', None)\n        if stop_words_ids is None:\n            stop_words_ids = getattr(self.generation_config, 'stop_words_ids', None)\n\n        if stop_words_ids is not None:\n            stop_words_logits_processor = StopWordsLogitsProcessor(\n                stop_words_ids=stop_words_ids, 
eos_token_id=self.generation_config.eos_token_id)\n            if logits_processor is None:\n                logits_processor = LogitsProcessorList([stop_words_logits_processor])\n            else:\n                logits_processor.append(stop_words_logits_processor)\n\n        # pass the optional arguments by keyword: in recent transformers versions the positional slot\n        # after synced_gpus is assistant_model, so positional passing would misroute streamer\n        return super().generate(\n            inputs,\n            generation_config=generation_config,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n            synced_gpus=synced_gpus,\n            streamer=streamer,\n            **kwargs,\n        )\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, base=10000):\n        super().__init__()\n        self.dim = dim\n        self.base = base\n        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))\n        if importlib.util.find_spec(\"einops\") is None:\n            raise RuntimeError(\"einops is required for Rotary Embedding\")\n\n        self._rotary_pos_emb_cache = None\n        self._seq_len_cached = 0\n        self._ntk_alpha_cached = 1.0\n\n    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):\n        seqlen = max_seq_len + offset\n        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:\n            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))\n            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim))\n            self._seq_len_cached = seqlen\n            self._ntk_alpha_cached = ntk_alpha\n            seq = torch.arange(seqlen, device=self.inv_freq.device)\n            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            from einops import rearrange\n\n            self._rotary_pos_emb_cache = rearrange(emb, \"n d -> 1 n 1 d\")\n\n    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):\n        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)\n        return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]\n\n\ndef _rotate_half(x):\n    from einops import rearrange\n\n    x = rearrange(x, \"... (j d) -> ... j d\", j=2)\n    x1, x2 = x.unbind(dim=-2)\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):\n    if use_flash_rotary:\n        t_ = t.float()\n        freqs = freqs.squeeze(0).squeeze(1)\n        cos = freqs[:, : freqs.shape[-1] // 2].cos()\n        sin = freqs[:, : freqs.shape[-1] // 2].sin()\n        output = apply_rotary_emb_func(t_, cos, sin).type_as(t)\n        return output\n    else:\n        rot_dim = freqs.shape[-1]\n        t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]\n        t_ = t_.float()\n        t_pass_ = t_pass_.float()\n        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())\n        return torch.cat((t_, t_pass_), dim=-1).type_as(t)\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        if rms_norm is not None:\n            return rms_norm(x, self.weight, self.eps)\n        else:\n            output = self._norm(x.float()).type_as(x)\n            return output * self.weight\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/qwen_generation_utils.py",
    "content": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"Generation support.\"\"\"\n\nfrom typing import Tuple, List, Union, Iterable\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom transformers import PreTrainedTokenizer\nfrom transformers import logging\nfrom transformers.generation import LogitsProcessor\n\nlogger = logging.get_logger(__name__)\n\n# Types.\nHistoryType = List[Tuple[str, str]]\nTokensType = List[int]\nBatchTokensType = List[List[int]]\n\n\ndef pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:\n    for tokens in batch:\n        context_length = len(tokens)\n        if context_length < seq_length:\n            tokens.extend([pad_id] * (seq_length - context_length))\n    return batch\n\n\ndef get_ltor_masks_and_position_ids(\n    data,\n    eod_token,\n    reset_position_ids,\n    reset_attention_mask,\n    eod_mask_loss,\n):\n    \"\"\"Build masks and position id for left to right model.\"\"\"\n\n    # Extract batch size and sequence length.\n    micro_batch_size, seq_length = data.size()\n\n    # Attention mask (lower triangular).\n    if reset_attention_mask:\n        att_mask_batch = micro_batch_size\n    else:\n        att_mask_batch = 1\n    attention_mask = torch.tril(\n        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)\n    ).view(att_mask_batch, 1, seq_length, seq_length)\n\n    # Loss mask.\n    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)\n    if eod_mask_loss:\n        loss_mask[data == eod_token] = 0.0\n\n    # Position ids.\n    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)\n    position_ids = position_ids.unsqueeze(0).expand_as(data)\n    # We need to clone as the ids will be modifed based on batch index.\n    if reset_position_ids:\n        position_ids = position_ids.clone()\n\n    if reset_position_ids or reset_attention_mask:\n        # Loop through the batches:\n        for b in range(micro_batch_size):\n\n            # Find indecies where EOD token is.\n            eod_index = position_ids[b, data[b] == eod_token]\n            # Detach indecies from positions if going to modify positions.\n            if reset_position_ids:\n                eod_index = eod_index.clone()\n\n            # Loop through EOD indecies:\n            prev_index = 0\n            for j in range(eod_index.size()[0]):\n                i = eod_index[j]\n                # Mask attention loss.\n                if reset_attention_mask:\n                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0\n                # Reset positions.\n                if reset_position_ids:\n                    position_ids[b, (i + 1) :] -= i + 1 - prev_index\n                    prev_index = i + 1\n\n    # Convert attention mask to binary:\n    attention_mask = attention_mask < 0.5\n\n    return attention_mask, loss_mask, position_ids\n\n\ndef get_batch(context_tokens: torch.LongTensor, eod_id: int):\n    \"\"\"Generate batch from context tokens.\"\"\"\n    # Move to GPU.\n    tokens = context_tokens.contiguous().to(context_tokens.device)\n    # Get the attention mask and postition ids.\n    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(\n        tokens,\n        eod_id,\n        reset_position_ids=False,\n        reset_attention_mask=False,\n        eod_mask_loss=False,\n    )\n    return tokens, 
attention_mask, position_ids\n\n\ndef get_stop_words_ids(chat_format, tokenizer):\n    if chat_format == \"raw\":\n        stop_words_ids = [tokenizer.encode(\"Human:\"), [tokenizer.eod_id]]\n    elif chat_format == \"chatml\":\n        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]\n    else:\n        raise NotImplementedError(f\"Unknown chat format {chat_format!r}\")\n    return stop_words_ids\n\n\ndef make_context(\n    tokenizer: PreTrainedTokenizer,\n    query: str,\n    history: List[Tuple[str, str]] = None,\n    system: str = \"\",\n    max_window_size: int = 6144,\n    chat_format: str = \"chatml\",\n):\n    if history is None:\n        history = []\n\n    if chat_format == \"chatml\":\n        im_start, im_end = \"<|im_start|>\", \"<|im_end|>\"\n        im_start_tokens = [tokenizer.im_start_id]\n        im_end_tokens = [tokenizer.im_end_id]\n        nl_tokens = tokenizer.encode(\"\\n\", add_special_tokens=False)\n\n        def _tokenize_str(role, content):\n            return f\"{role}\\n{content}\", tokenizer.encode(\n                role, add_special_tokens=False\n            ) + nl_tokens + tokenizer.encode(content, add_special_tokens=False)\n\n        system_text, system_tokens_part = _tokenize_str(\"system\", system)\n        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens\n\n        raw_text = \"\"\n        context_tokens = []\n\n        for turn_query, turn_response in reversed(history):\n            query_text, query_tokens_part = _tokenize_str(\"user\", turn_query)\n            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens\n            response_text, response_tokens_part = _tokenize_str(\n                \"assistant\", turn_response\n            )\n            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens\n\n            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens\n            prev_chat = (\n                f\"\\n{im_start}{query_text}{im_end}\\n{im_start}{response_text}{im_end}\"\n            )\n\n            current_context_size = (\n                    len(system_tokens) + len(next_context_tokens) + len(context_tokens)\n            )\n            if current_context_size < max_window_size:\n                context_tokens = next_context_tokens + context_tokens\n                raw_text = prev_chat + raw_text\n            else:\n                break\n\n        context_tokens = system_tokens + context_tokens\n        raw_text = f\"{im_start}{system_text}{im_end}\" + raw_text\n        context_tokens += (\n                nl_tokens\n                + im_start_tokens\n                + _tokenize_str(\"user\", query)[1]\n                + im_end_tokens\n                + nl_tokens\n                + im_start_tokens\n                + tokenizer.encode(\"assistant\", add_special_tokens=False)\n                + nl_tokens\n        )\n        raw_text += f\"\\n{im_start}user\\n{query}{im_end}\\n{im_start}assistant\\n\"\n\n    elif chat_format == \"raw\":\n        raw_text = query\n        context_tokens = tokenizer.encode(raw_text)\n    else:\n        raise NotImplementedError(f\"Unknown chat format {chat_format!r}\")\n\n    return raw_text, context_tokens\n\n\ndef _decode_default(\n    tokens: List[int],\n    *,\n    stop_words: List[str],\n    eod_words: List[str],\n    tokenizer: PreTrainedTokenizer,\n    raw_text_len: int,\n    verbose: bool = False,\n    return_end_reason: bool = False,\n):\n    trim_decode_tokens = tokenizer.decode(tokens)[raw_text_len:]\n 
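   # the prompt text was sliced off above; strip stop words and truncate at the first EOD word below\n 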
   if verbose:\n        print(\"\\nRaw Generate: \", trim_decode_tokens)\n\n    end_reason = f\"Gen length {len(tokens)}\"\n    for stop_word in stop_words:\n        trim_decode_tokens = trim_decode_tokens.replace(stop_word, \"\").strip()\n    for eod_word in eod_words:\n        if eod_word in trim_decode_tokens:\n            end_reason = f\"Gen {eod_word!r}\"\n        trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]\n    trim_decode_tokens = trim_decode_tokens.strip()\n    if verbose:\n        print(\"\\nEnd Reason:\", end_reason)\n        print(\"\\nGenerate: \", trim_decode_tokens)\n\n    if return_end_reason:\n        return trim_decode_tokens, end_reason\n    else:\n        return trim_decode_tokens\n\n\ndef _decode_chatml(\n    tokens: List[int],\n    *,\n    stop_words: List[str],\n    eod_token_ids: List[int],\n    tokenizer: PreTrainedTokenizer,\n    raw_text_len: int,\n    context_length: int,\n    verbose: bool = False,\n    return_end_reason: bool = False,\n):\n    end_reason = f\"Gen length {len(tokens)}\"\n    eod_token_idx = context_length\n    for eod_token_idx in range(context_length, len(tokens)):\n        if tokens[eod_token_idx] in eod_token_ids:\n            end_reason = f\"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}\"\n            break\n\n    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]\n    if verbose:\n        print(\"\\nRaw Generate w/o EOD:\", tokenizer.decode(tokens)[raw_text_len:])\n        print(\"\\nRaw Generate:\", trim_decode_tokens)\n        print(\"\\nEnd Reason:\", end_reason)\n    for stop_word in stop_words:\n        trim_decode_tokens = trim_decode_tokens.replace(stop_word, \"\").strip()\n    trim_decode_tokens = trim_decode_tokens.strip()\n    if verbose:\n        print(\"\\nGenerate:\", trim_decode_tokens)\n\n    if return_end_reason:\n        return trim_decode_tokens, end_reason\n    else:\n        return trim_decode_tokens\n\n\ndef decode_tokens(\n    tokens: Union[torch.LongTensor, TokensType],\n    tokenizer: PreTrainedTokenizer,\n    raw_text_len: int,\n    context_length: int,\n    chat_format: str,\n    verbose: bool = False,\n    return_end_reason: bool = False,\n) -> str:\n    if torch.is_tensor(tokens):\n        tokens = tokens.cpu().numpy().tolist()\n\n    if chat_format == \"chatml\":\n        return _decode_chatml(\n            tokens,\n            stop_words=[],\n            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],\n            tokenizer=tokenizer,\n            raw_text_len=raw_text_len,\n            context_length=context_length,\n            verbose=verbose,\n            return_end_reason=return_end_reason,\n        )\n    elif chat_format == \"raw\":\n        return _decode_default(\n            tokens,\n            stop_words=[\"<|endoftext|>\"],\n            eod_words=[\"<|endoftext|>\"],\n            tokenizer=tokenizer,\n            raw_text_len=raw_text_len,\n            verbose=verbose,\n            return_end_reason=return_end_reason,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown chat format {chat_format!r}\")\n\n\nclass StopWordsLogitsProcessor(LogitsProcessor):\n    \"\"\"\n    :class:`transformers.LogitsProcessor` that stops generation when the specified sequences appear.\n\n    Args:\n        stop_words_ids (:obj:`List[List[int]]`):\n            List of lists of token ids of the stop words. 
In order to get the token ids of the words\n            that should stop generation, use :obj:`tokenizer(stop_word,\n            add_prefix_space=True).input_ids`.\n        eos_token_id (:obj:`int`):\n            The id of the `end-of-sequence` token.\n    \"\"\"\n\n    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):\n\n        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:\n            raise ValueError(\n                f\"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}.\"\n            )\n        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):\n            raise ValueError(\n                f\"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.\"\n            )\n        if any(\n            any(\n                (not isinstance(token_id, (int, np.integer)) or token_id < 0)\n                for token_id in stop_word_ids\n            )\n            for stop_word_ids in stop_words_ids\n        ):\n            raise ValueError(\n                f\"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}.\"\n            )\n\n        self.stop_words_ids = list(\n            filter(\n                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids\n            )\n        )\n        self.eos_token_id = eos_token_id\n        for stop_token_seq in self.stop_words_ids:\n            assert (\n                len(stop_token_seq) > 0\n            ), \"Stop word sequences in {} cannot be empty\".format(\n                stop_words_ids\n            )\n\n    def __call__(\n        self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        stopped_samples = self._calc_stopped_samples(input_ids)\n        for i, should_stop in enumerate(stopped_samples):\n            if should_stop:\n                scores[i, self.eos_token_id] = float(2**15)\n        return scores\n\n    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:\n        if len(tokens) == 0:\n            # an empty stop-word sequence matches trivially\n            return True\n        elif len(tokens) > len(prev_tokens):\n            # a stop-word sequence longer than prev_tokens cannot match\n            return False\n        elif prev_tokens[-len(tokens) :].tolist() == tokens:\n            # the tail of prev_tokens equals the stop-word sequence\n            return True\n        else:\n            return False\n\n    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:\n        stopped_samples = []\n        for prev_input_ids_slice in prev_input_ids:\n            match = False\n            for stop_token_seq in self.stop_words_ids:\n                if self._tokens_match(prev_input_ids_slice, stop_token_seq):\n                    # a stop-word sequence matched at the end of this sample\n                    match = True\n                    break\n            stopped_samples.append(match)\n\n        return stopped_samples\n\n\ndef top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float(\"Inf\")):\n    \"\"\"This function has been mostly taken from huggingface conversational\n    ai code at\n        https://medium.com/huggingface/how-to-build-a-state-of-the-art-\n             conversational-ai-with-transfer-learning-2d818ac26313\"\"\"\n\n    if top_k > 0:\n        # Remove all tokens with a probability less than the\n        # last token of the top-k\n        indices_to_remove = logits < 
torch.topk(logits, top_k)[0][..., -1, None]\n        logits[indices_to_remove] = filter_value\n\n    if top_p > 0.0:\n        # Sort logits in descending order to compute cumulative probabilities\n        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)\n        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n\n        # Remove tokens with cumulative probability above the threshold\n        sorted_indices_to_remove = cumulative_probs > top_p\n        # Shift the indices to the right to keep also the first token\n        # above the threshold\n        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()\n        sorted_indices_to_remove[..., 0] = 0\n        for i in range(sorted_indices.size(0)):\n            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]\n            logits[i][indices_to_remove] = filter_value\n\n    return logits\n\n\ndef switch(val1, val2, boolean):\n    boolean = boolean.type_as(val1)\n    return (1 - boolean) * val1 + boolean * val2\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/tokenization_qwen.py",
    "content": "# Copyright (c) Alibaba Cloud.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"Tokenization classes for QWen.\"\"\"\n\nfrom __future__ import absolute_import, division, print_function, unicode_literals\n\nimport json\nimport logging\nimport os\nimport unicodedata\nfrom io import open\nimport base64\nimport tiktoken\nfrom typing import List, Optional, Tuple, Union\n\nfrom transformers import PreTrainedTokenizer, AddedToken\n\nlogger = logging.getLogger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"qwen.tiktoken\"}\n\n\nclass QWenTokenizer(PreTrainedTokenizer):\n    \"\"\"QWen tokenizer.\"\"\"\n\n    \"\"\"NOTE: This tokenizer will not handle special tokens to avoid injection attacks\"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n\n    def __init__(\n        self,\n        vocab_file,\n        errors=\"replace\",\n        max_len=None,\n        unk_token=\"<|endoftext|>\",\n        bos_token=\"<|endoftext|>\",\n        eos_token=\"<|endoftext|>\",\n        pad_token=None,\n        add_prefix_space=False,\n        add_bos_token=False,\n        add_more_sp_tokens=True,\n        **kwargs,\n    ):\n        bos_token = (\n            AddedToken(bos_token, lstrip=False, rstrip=False)\n            if isinstance(bos_token, str)\n            else bos_token\n        )\n        eos_token = (\n            AddedToken(eos_token, lstrip=False, rstrip=False)\n            if isinstance(eos_token, str)\n            else eos_token\n        )\n        unk_token = (\n            AddedToken(unk_token, lstrip=False, rstrip=False)\n            if isinstance(unk_token, str)\n            else unk_token\n        )\n        pad_token = (\n            AddedToken(pad_token, lstrip=False, rstrip=False)\n            if isinstance(pad_token, str)\n            else pad_token\n        )\n        super().__init__(\n            errors=errors,\n            unk_token=unk_token,\n            bos_token=bos_token,\n            eos_token=eos_token,\n            pad_token=pad_token,\n            add_prefix_space=add_prefix_space,\n            add_bos_token=add_bos_token,\n        )\n        self.add_bos_token = add_bos_token\n        self.max_len = max_len if max_len is not None else int(1e12)\n\n        self.errors = errors  # how to handle errors in decoding\n\n        name = \"Qwen\"\n        ENDOFTEXT = \"<|endoftext|>\"\n        IMSTART = \"<|im_start|>\"\n        IMEND = \"<|im_end|>\"\n        if add_more_sp_tokens:\n            special_tokens = (\n                ENDOFTEXT,\n                IMSTART,\n                IMEND,\n                \"<R>\",\n                \"<S>\",\n                \"<X>\",\n                \"<mask>\",\n                \"<sep>\",\n            ) + tuple([f\"<extra_{i}>\" for i in range(200)])\n        else:\n            special_tokens = (ENDOFTEXT, IMSTART, IMEND)\n\n        PAT_STR = r\"\"\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\"\"\"\n\n        def load_tiktoken_bpe(tiktoken_bpe_file: str) -> \"dict[bytes, int]\":\n            contents = open(tiktoken_bpe_file, \"rb\").read()\n            return {\n                base64.b64decode(token): int(rank)\n                for token, rank in (\n                    line.split() for line in contents.splitlines() if line\n                )\n            }\n\n        mergeable_ranks = load_tiktoken_bpe(vocab_file)\n        special_tokens = {\n            token: 
index\n            for index, token in enumerate(special_tokens, start=len(mergeable_ranks))\n        }\n        self.special_tokens = special_tokens\n        enc = tiktoken.Encoding(\n            name,\n            pat_str=PAT_STR,\n            mergeable_ranks=mergeable_ranks,\n            special_tokens=special_tokens,\n        )\n        assert (\n            len(mergeable_ranks) + len(special_tokens) == enc.n_vocab\n        ), f\"{len(mergeable_ranks) + len(special_tokens)} != {enc.n_vocab} in encoding\"\n\n        self.mergeable_ranks = mergeable_ranks\n        self.encoder = self.mergeable_ranks\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.tokenizer = enc  # type: tiktoken.Encoding\n        self.eod_id = self.tokenizer.eot_token\n        self.im_start_id = special_tokens[IMSTART]\n        self.im_end_id = special_tokens[IMEND]\n\n    def __len__(self):\n        return self.tokenizer.n_vocab\n\n    def get_vocab(self):\n        return self.mergeable_ranks\n\n    def convert_tokens_to_ids(self, tokens):\n        ids = []\n        # Remove support for py2\n        if isinstance(tokens, str):\n            if tokens in self.special_tokens:\n                return self.special_tokens[tokens]\n            else:\n                return self.encoder.get(tokens)\n        for token in tokens:\n            if token in self.special_tokens:\n                ids.append(self.special_tokens[token])\n            else:\n                ids.append(self.encoder.get(token))\n        if len(ids) > self.max_len:\n            logger.warning(\n                \"Token indices sequence length is longer than the specified maximum \"\n                \" sequence length for this model ({} > {}). Running this\"\n                \" sequence through the model will result in indexing errors\".format(\n                    len(ids), self.max_len\n                )\n            )\n        return ids\n\n    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:\n        \"\"\"\n        Save only the vocabulary of the tokenizer (vocabulary + added tokens).\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        file_path = os.path.join(save_directory, \"qwen.tiktoken\")\n        with open(file_path, \"w\", encoding=\"utf8\") as w:\n            for k, v in self.mergeable_ranks.items():\n                line = base64.b64encode(k).decode(\"utf8\") + \" \" + str(v) + \"\\n\"\n                w.write(line)\n        return (file_path,)\n\n    def tokenize(self, text: str, **kwargs) -> List[str]:\n        \"\"\"\n        Converts a string in a sequence of tokens, replacing unknown tokens with the `unk_token`.\n\n        Args:\n            text (`str`):\n                The sequence to be encoded.\n            kwargs (additional keyword arguments, *optional*):\n                Will be passed to the underlying model specific encode method. See details in\n                [`~PreTrainedTokenizerBase.__call__`]\n\n        Returns:\n            `List[str]`: The list of tokens.\n        \"\"\"\n        tokens = []\n        text = unicodedata.normalize(\"NFC\", text)\n        for t in self.tokenizer.encode_ordinary(text):\n            tokens.append(self.decoder[t])\n        return tokens\n\n    def convert_tokens_to_string(self, tokens: List[str]) -> str:\n        \"\"\"\n        Converts a sequence of tokens in a single string. 
Tokens produced by this tokenizer are raw bytes\n        (values of `mergeable_ranks`), so they are joined at the byte level before decoding to UTF-8.\n        \"\"\"\n        text = b\"\"\n        for token in tokens:\n            if isinstance(token, bytes):\n                text += token\n            else:\n                text += token.encode(\"utf-8\")\n        return text.decode(\"utf-8\", errors=self.errors)\n\n    @property\n    def vocab_size(self):\n        return self.tokenizer.n_vocab\n\n    def _convert_id_to_token(self, index: int) -> str:\n        if index >= self.tokenizer.n_vocab:\n            return self.unk_token\n        return self.tokenizer.decode([index])\n\n    def _convert_token_to_id(self, token: str) -> int:\n        \"\"\"Converts a token to an id using the vocab.\"\"\"\n        return self.encoder.get(token.encode('UTF-8'), self.tokenizer.encode(self.unk_token, allowed_special='all')[0])\n\n    @property\n    def all_special_tokens(self) -> List[str]:\n        \"\"\"\n        `List[str]`: All the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.\n\n        Convert tokens of `tokenizers.AddedToken` type to string.\n        \"\"\"\n        all_toks = [str(s) for s in self.special_tokens.keys()]\n        return all_toks\n\n    @property\n    def all_special_ids(self) -> List[int]:\n        \"\"\"\n        `List[int]`: List of the ids of the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.\n        \"\"\"\n        all_ids = [v for v in self.special_tokens.values()]\n        return all_ids\n\n    def _tokenize(self, text, **kwargs):\n        \"\"\"\n        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based\n        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).\n\n        Does NOT take care of added tokens.\n        \"\"\"\n        raise NotImplementedError\n\n    def _decode(\n        self,\n        token_ids: Union[int, List[int]],\n        skip_special_tokens: bool = False,\n        **kwargs,\n    ) -> str:\n        if isinstance(token_ids, int):\n            token_ids = [token_ids]\n        if skip_special_tokens:\n            token_ids = [i for i in token_ids if i not in self.all_special_ids]\n        return self.tokenizer.decode(token_ids)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        # bos_token_id = [self.im_start_id]\n        # eos_token_id = [self.im_end_id]\n        bos_token_id = [self.bos_token_id]\n        eos_token_id = [self.eos_token_id]\n        line_break_token_id = [self._convert_token_to_id(\"\\n\")]\n\n        output = bos_token_id + token_ids_0 + eos_token_id + line_break_token_id\n\n        if token_ids_1 is not None:\n            output += bos_token_id + token_ids_1 + eos_token_id + line_break_token_id\n\n        return output\n"
  },
  {
    "path": "src/resources/models/qwen-7B-chat/tokenizer_config.json",
    "content": "{\n  \"remove_space\": false,\n  \"do_lower_case\": false,\n  \"tokenizer_class\": \"QWenTokenizer\",\n  \"auto_map\": {\n    \"AutoTokenizer\": [\n      \"tokenization_qwen.QWenTokenizer\",\n      null\n      ]\n  },\n  \"bos_token\": {\n      \"__type\": \"AddedToken\",\n      \"content\": \"<|im_start|>\",\n      \"lstrip\": false,\n      \"normalized\": true,\n      \"rstrip\": false,\n      \"single_word\": false\n    },\n  \"eos_token\": {\n    \"__type\": \"AddedToken\",\n    \"content\": \"<|im_end|>\",\n    \"lstrip\": false,\n    \"normalized\": true,\n    \"rstrip\": false,\n    \"single_word\": false\n  }\n}\n"
  },
  {
    "path": "src/train_dpo.py",
    "content": "import sys\n\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport evaluate\nimport torch\nimport copy\n\nfrom torch.utils.data import SequentialSampler, DataLoader\nfrom tqdm import tqdm\nfrom transformers import (\n    TrainingArguments,\n    default_data_collator,\n)\n\nfrom src.utils import RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.data.data import DPODataset, SFTDataset\nfrom src.utils.file_utils import set_seed, print_rank_0\nfrom src.models.trainer import DPOTrainer\n\n\n# Create a preprocessing function to extract out the proper logits from the model output\ndef preprocess_logits_for_metrics(logits, labels):\n    if isinstance(logits, tuple):\n        logits = logits[0]\n\n    return logits.argmax(dim=-1)\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--tokenizer_path\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--reference_model_name_or_path\", type=str, default=None)\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    parser.add_argument(\"--max_length_prompt\", type=int, default=512)\n    parser.add_argument(\"--max_length_generation\", type=int, default=None)\n    parser.add_argument(\"--bits\", type=int, default=32,\n                        help=\"bits used to load model, including: 32, 16, 8, 4\")\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--beta\", type=float, default=0.1, help=\"the beta parameter for DPO loss\")\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    
parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=None)\n    parser.add_argument(\"--top_p\", type=float, default=None)\n    parser.add_argument(\"--temperature\", type=float, default=None)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load tokenizer and model\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args)\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, model, strict=False)\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = DPODataset(args, os.path.join(args.data_dir, args.train_filename),\n                                   tokenizer)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        dev_dataset = DPODataset(args, os.path.join(args.data_dir, args.eval_filename),\n                                 tokenizer)\n    else:\n        dev_dataset = None\n\n    if args.do_train:\n        if torch.cuda.is_available():\n            bf16 = torch.cuda.get_device_capability()[0] >= 8\n            fp16 = not bf16\n        else:\n            fp16 = False\n            
bf16 = False\n        # training arguments\n        deepspeed_config = os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config) if args.deepspeed_config is not None else None\n        training_args = TrainingArguments(\n            output_dir=args.output_dir,\n            no_cuda=not torch.cuda.is_available(),\n            seed=args.seed,\n            data_seed=args.seed,\n            local_rank=args.local_rank,\n            do_train=args.do_train,\n            num_train_epochs=args.num_epochs,\n            learning_rate=args.learning_rate,\n            lr_scheduler_type=args.lr_scheduler_type,\n            per_device_train_batch_size=args.train_batch_size,\n            gradient_accumulation_steps=args.gradient_accumulation_steps,\n            warmup_ratio=args.warmup_ratio,\n            weight_decay=args.weight_decay,\n            half_precision_backend=\"auto\",\n            fp16=fp16,\n            bf16=bf16,\n            optim=\"paged_adamw_8bit\",\n            # adam_beta1=0.9,\n            # adam_beta2=0.95,\n            save_strategy=args.save_strategy,\n            save_steps=args.save_steps,\n            save_total_limit=args.save_total_limit,\n            metric_for_best_model=args.metric_for_best_model,\n            greater_is_better=True,\n            logging_steps=args.logging_steps,\n            report_to=[\"tensorboard\"],\n            deepspeed=deepspeed_config,\n            gradient_checkpointing=args.gradient_checkpointing,\n            do_eval=args.do_eval,\n            evaluation_strategy=args.evaluation_strategy,\n            eval_steps=args.eval_steps,\n            eval_accumulation_steps=args.eval_accumulation_steps,\n            per_device_eval_batch_size=args.eval_batch_size,\n            # do_predict=args.do_pred,\n            # use_legacy_prediction_loop=args.do_pred,\n            remove_unused_columns=False,\n        )\n        print_rank_0(f\"Training Arguments: {training_args}\")\n\n        # load reference model or precomputed reference result\n        if args.output_filename is not None:\n            logps = torch.load(os.path.join(args.output_dir, args.output_filename))\n            ref_model = None\n        else:\n            logps = None\n            ref_args = copy.deepcopy(args)\n            ref_args.device_map = \"auto\"\n            if args.reference_model_name_or_path is not None:\n                ref_args.model_name_or_path = args.reference_model_name_or_path\n            else:\n                ref_args.bits = 4\n            _, ref_model, _ = load_tokenizer_and_model(ref_args)\n            ref_model.eval()\n\n        # Prepare the trainer and start training\n        trainer = DPOTrainer(\n            model=model,\n            ref_model=ref_model,\n            logps=logps,\n            args=training_args,\n            beta=args.beta,\n            train_dataset=train_dataset,\n            eval_dataset=dev_dataset,\n            tokenizer=tokenizer,\n            data_collator=default_data_collator,\n            # compute_metrics=compute_metrics,\n            # preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n            label_pad_token_id=tokenizer.pad_token_id\n        )\n        # model.config.use_cache = False  # silence the warnings. 
Please re-enable for inference!\n        trainer.train()\n        trainer.save_model(args.output_dir)\n\n    elif args.do_eval:\n        # res = trainer.evaluate(eval_dataset=dev_dataset)\n        # print_rank_0(res)\n        pass\n\n    if args.do_pred:\n        def _get_batch_logps(\n                logits: torch.FloatTensor,\n                labels: torch.LongTensor,\n                average_log_prob: bool = False,\n        ) -> torch.FloatTensor:\n            \"\"\"Compute the log probabilities of the given labels under the given logits.\n\n            Args:\n                logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)\n                labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)\n                average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.\n\n            Returns:\n                A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.\n            \"\"\"\n            if logits.shape[:-1] != labels.shape:\n                raise ValueError(\"Logits (batch and sequence length dim) and labels must have the same shape.\")\n\n            labels = labels[:, 1:].clone()\n            logits = logits[:, :-1, :]\n            loss_mask = labels != tokenizer.pad_token_id\n\n            # dummy token; we'll ignore the losses on these tokens later\n            labels[labels == tokenizer.pad_token_id] = 0\n\n            per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)\n\n            if average_log_prob:\n                return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)\n            else:\n                return (per_token_logps * loss_mask).sum(-1)\n\n        model.eval()\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() and args.device_map is not None else \"cpu\"\n\n        logps = dict()\n        for test_filename in args.test_filename.split(\",\"):\n            if \"train\" in test_filename:\n                mode = \"train\"\n            else:\n                mode = \"eval\"\n            logps[mode] = dict()\n            test_filename = os.path.join(args.data_dir, test_filename)\n            test_dataset = DPODataset(args, test_filename, tokenizer)\n            sampler = SequentialSampler(test_dataset)\n            test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, sampler=sampler)\n            with torch.no_grad():\n                for batch in tqdm(test_loader, desc=f\"Prediction on {mode}\"):\n                    indices = batch['index'].tolist()\n                    chosen_input_ids = batch['chosen_input_ids'].to(device)\n                    chosen_attention_mask = batch['chosen_attention_mask'].to(device) if 'chosen_attention_mask' in batch else None\n                    rejected_input_ids = batch['rejected_input_ids'].to(device)\n                    rejected_attention_mask = batch['rejected_attention_mask'].to(device) if 'rejected_attention_mask' in batch else None\n                    chosen_logits = model(chosen_input_ids, chosen_attention_mask).logits.detach().cpu().to(torch.float32)\n                    chosen_logps = _get_batch_logps(chosen_logits, batch[\"chosen_labels\"], average_log_prob=False)\n                    rejected_logits = 
model(rejected_input_ids, rejected_attention_mask).logits.detach().cpu().to(torch.float32)\n                    rejected_logps = _get_batch_logps(rejected_logits, batch[\"rejected_labels\"], average_log_prob=False)\n                    for index, chosen_logp, rejected_logp in zip(indices, chosen_logps, rejected_logps):\n                        logps[mode][index] = {\"chosen_logp\": chosen_logp, \"rejected_logp\": rejected_logp}\n\n        torch.save(logps, os.path.join(args.output_dir, args.output_filename))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/train_reward.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/Users/zeyesun/Documents/Code/RLHF\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport torch\nimport argparse\n\nfrom tqdm import tqdm\nfrom transformers import (\n    Trainer,\n    TrainingArguments,\n    default_data_collator,\n)\nfrom torch.utils.data import DataLoader, SequentialSampler\n\nfrom src.models.reward import RewardModel\nfrom src.utils import logger, RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.utils.file_utils import set_seed, print_rank_0\nfrom src.data.data import SFTDataset, PairwiseDataset, DataCollatorReward\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n    parser.add_argument(\"--tokenizer_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    parser.add_argument(\"--bits\", type=int, default=16)\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--freeze_ratio\", type=float, default=0.0, help=\"ratio of layers frozen for reward training\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n    parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    
parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load model and tokenizer\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args)\n\n    # Initialize the reward model from the (supervised) fine-tuned SFT model\n    if \"pangu\" in args.model_name_or_path.lower():\n        reward_model = RewardModel(model.config, model.transformer, tokenizer)\n    elif \"chatglm\" in args.model_name_or_path.lower():\n        reward_model = RewardModel(model.config, model.transformer, tokenizer)\n    elif \"glm\" in args.model_name_or_path.lower():\n        reward_model = RewardModel(model.config, model.glm, tokenizer)\n    else:\n        reward_model = RewardModel(model.config, model.model, tokenizer)\n    assert model.config.pad_token_id == tokenizer.pad_token_id\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, reward_model, strict=False)\n\n    if args.device_map is not None and args.bits not in [4, 8]:\n        reward_model = reward_model.half()\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = PairwiseDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                        tokenizer)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        val_dataset = PairwiseDataset(args, os.path.join(args.data_dir, args.eval_filename),\n                                      tokenizer)\n    else:\n        val_dataset = None\n    if args.do_pred:\n        test_dataset = SFTDataset(args, os.path.join(args.data_dir, args.test_filename),\n                                  tokenizer, concat_samples=False)\n    else:\n        test_dataset = None\n\n    if args.do_train or args.do_eval:\n        # training arguments\n        deepspeed_config 
= os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config) if args.deepspeed_config is not None else None\n        if torch.cuda.is_available():\n            bf16 = torch.cuda.get_device_capability()[0] >= 8\n            fp16 = False if bf16 else True\n        else:\n            fp16 = False\n            bf16 = False\n        training_args = TrainingArguments(\n            output_dir=args.output_dir,\n            no_cuda=not torch.cuda.is_available(),\n            seed=args.seed,\n            data_seed=args.seed,\n            local_rank=args.local_rank,\n            do_train=args.do_train,\n            num_train_epochs=args.num_epochs,\n            learning_rate=args.learning_rate,\n            lr_scheduler_type=args.lr_scheduler_type,\n            per_device_train_batch_size=args.train_batch_size,\n            max_grad_norm=args.max_grad_norm,\n            gradient_accumulation_steps=args.gradient_accumulation_steps,\n            warmup_ratio=args.warmup_ratio,\n            weight_decay=args.weight_decay,\n            half_precision_backend=\"auto\",\n            fp16=fp16,\n            bf16=bf16,\n            save_strategy=args.save_strategy,\n            save_steps=args.save_steps,\n            save_total_limit=args.save_total_limit,\n            metric_for_best_model=args.metric_for_best_model,\n            greater_is_better=True,\n            logging_steps=args.logging_steps,\n            report_to=[\"tensorboard\"],\n            deepspeed=deepspeed_config,\n            gradient_checkpointing=args.gradient_checkpointing,\n            do_eval=args.do_eval,\n            evaluation_strategy=args.evaluation_strategy,\n            eval_steps=args.eval_steps,\n            eval_accumulation_steps=args.eval_accumulation_steps,\n            per_device_eval_batch_size=args.eval_batch_size,\n            label_names=[\"labels\"],\n            # do_predict=args.do_pred,\n            # use_legacy_prediction_loop=args.do_pred,\n        )\n        print_rank_0(f\"Training Arguments: {training_args}\")\n\n        def compute_metrics(eval_preds):\n            chosen_end_scores = eval_preds.predictions[1]  # chosen scores\n            rejected_end_scores = eval_preds.predictions[3]  # rejected scores\n            result = {}\n            acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores)\n            result[\"accuracy\"] = acc\n\n            return result\n\n        # Prepare the trainer and start training\n        trainer = Trainer(\n            model=reward_model,\n            args=training_args,\n            train_dataset=train_dataset,\n            compute_metrics=compute_metrics,\n            eval_dataset=val_dataset,\n        )\n\n        if args.do_train:\n            trainer.train()\n            trainer.save_model(args.output_dir)\n        elif args.do_eval:\n            eval_result = trainer.evaluate(eval_dataset=val_dataset)\n            print_rank_0(eval_result)\n\n    if args.do_pred:\n        reward_model.eval()\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() and args.device_map is not None else \"cpu\"\n        sampler = SequentialSampler(test_dataset)\n        test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, sampler=sampler)\n        rewards = []\n        with torch.no_grad():\n            for batch in tqdm(test_loader, desc=\"Prediction\"):\n                chosen_input_ids = batch['input_ids'].to(device)\n                chosen_attention_mask = batch['attention_mask'].to(device) if 
'attention_mask' in batch else None\n                chosen_position_ids = batch['position_ids'].to(device) if 'position_ids' in batch else None\n                output = reward_model(chosen_input_ids, chosen_attention_mask, chosen_position_ids)\n                rewards.extend(output['chosen_reward'].cpu().detach().tolist())\n        # save results to file\n        with open(os.path.join(args.output_dir, args.output_filename), \"w\", encoding=\"utf-8\") as w:\n            w.write(\"\\t\".join((\"prompt\", \"answer\", \"score\"))+\"\\n\")\n            for item, reward in zip(test_dataset.post_list, rewards):\n                w.write(\"\\t\".join((item[\"prompt\"], item[\"label\"], str(reward))) + \"\\n\")\n        print_rank_0(f\"Finished prediction, results saved to {args.output_filename}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/train_rlhf.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\n\nimport os\nimport argparse\nimport torch\nimport random\nimport copy\nimport deepspeed\nimport numpy as np\n\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, default_data_collator\nfrom torch.utils.data import RandomSampler, DistributedSampler, DataLoader\n\nfrom src.models.rlhf_engine import DeepSpeedRLHFEngine\nfrom src.models.trainer import DeepSpeedPPOTrainer, DeepSpeedPPOPTXTrainer\nfrom src.utils.file_utils import set_seed, print_rank_0, print_gpu_utilization_torch, print_gpu_utilization\nfrom src.data.data import SFTDataset, RLHFDataset, PPODataset\nfrom src.utils.modeling_utils import get_all_reduce_mean, save_hf_format, moving_average, save_zero_three_model\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--tokenizer_path\", type=str, required=True)\n    parser.add_argument(\"--actor_model_path\", type=str, required=True)\n    parser.add_argument(\"--critic_model_path\", type=str, required=True)\n    parser.add_argument(\"--critic_checkpoint\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=512,\n                        help=\"total max sequence length = max prompt length + mas generation/answer length\")\n    parser.add_argument(\"--max_gen_length\", type=int, default=256,\n                        help=\"max generation/answer length\")\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--pretrain_filename\", type=str, default=None,\n                        help=\"pretraining dataset (for PPO-ptx)\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1,\n                        help=\"Total number of training epochs to perform.\")\n    parser.add_argument(\"--ppo_epochs\", type=int, default=1,\n                        help=\"Number of epochs to perform ppo training for each experience\")\n    parser.add_argument(\"--actor_learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--critic_learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"linear\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--ppo_train_batch_size\", type=int, default=4,\n                        help=\"PPO training mini batch size (per device)\")\n    parser.add_argument(\"--ppo_batch_numbers\", type=int, default=1,\n                        help=\"number of batches for PPO training\")\n    parser.add_argument(\"--actor_weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--critic_weight_decay\", type=float, default=0.1)\n    
parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument('--disable_actor_dropout', action='store_true',\n                        help='Disable the dropout of the actor model.')\n    parser.add_argument('--disable_critic_dropout', action='store_true',\n                        help='Disable the dropout of the critic model.')\n    parser.add_argument(\"--pretrain_coef\", type=float, default=10.0,\n                        help=\"coefficient of pretraining loss in ppo-ptx objective function\")\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=50)\n    parser.add_argument(\"--top_p\", type=float, default=0.8)\n    parser.add_argument(\"--temperature\", type=float, default=1.0)\n    parser.add_argument(\"--kl_coefficient\", type=float, default=0.02,\n                        help=\"Coefficient of KL divergence, used in computing modified reward\")\n    parser.add_argument(\"--clip_reward_value\", type=float, default=5.0)\n    parser.add_argument(\"--clip_range\", type=float, default=0.2,\n                        help=\"Clip range of policy, used in computing clipped policy loss\")\n    parser.add_argument(\"--clip_range_value\", type=float, default=0.2,\n                        help=\"Clip range of value function, used in computing clipped value function loss\")\n    parser.add_argument(\"--gamma\", type=float, default=1.0,\n                        help=\"first decaying factor, used in computing advantages\")\n    parser.add_argument(\"--lambda_\", type=float, default=0.95,\n                        help=\"second decaying factor, used in computing advantages\")\n    parser.add_argument('--enable_joint_optimization', action='store_true',\n                        help=\"Enable joint optimization of policy and reward\")\n    # deepspeed\n    parser.add_argument('--enable_hybrid_engine', action='store_true',\n                        help=\"Enable hybrid engine for actor model to optimize both inference and training through DeepSpeed.\")\n    parser.add_argument('--actor_zero_stage', type=int, default=0,\n                        help='ZeRO optimization stage for Actor model (and clones).')\n    parser.add_argument('--critic_zero_stage', type=int, default=0,\n                        help='ZeRO optimization stage for Critic model (and reward).')\n    parser.add_argument('--offload', action='store_true', help='Enable ZeRO Offload techniques.')\n    parser.add_argument('--offload_reference_model', action='store_true',\n                        help='Enable ZeRO Offload techniques for reference model')\n    parser.add_argument(\"--actor_gradient_checkpointing\", action=\"store_true\",\n                        help=\"whether to use gradient checkpointing for actor model\")\n    parser.add_argument(\"--critic_gradient_checkpointing\", action=\"store_true\",\n                        help=\"whether to use gradient checkpointing for critic model\")\n    parser.add_argument(\"--unpin_actor_parameters\", action='store_true',\n                        help=\"Unpin actor's parameters during generation. 
This makes generation slower but requires less memory.\")\n    parser.add_argument(\"--release_inference_cache\", action='store_true',\n                        help=\"Release the memory cache used for inference. This makes generation preparation slower but might increase e2e throughput by using larger batch size.\")\n    parser.add_argument(\"--inference_tp_size\", type=int, default=1,\n                        help=\"Tensor-parallelism degree used for the inference-optimization. Please note hybrid-engine needs to be enabled when using this feature.\")\n    parser.add_argument(\"--tp_gather_partition_size\", type=int, default=8,\n                        help=\"Granularity to bring in layers for TP sharding inside the hybrid engine. Please note hybrid-engine and inference_tp_size > 1 need to be true when using this feature.\")\n    # parser.add_argument(\"--num_layers_unfrozen\", type=int, default=-1, help=\"Number of layers to unfreeze for fine-tuning\")\n    parser.add_argument('--enable_ema', action='store_true', help='Enable EMA checkpoint for the model.')\n    # lora\n    parser.add_argument(\"--actor_lora_rank\", type=int, default=0)\n    parser.add_argument(\"--actor_lora_alpha\", type=int, default=16)\n    parser.add_argument(\"--critic_lora_rank\", type=int, default=0)\n    parser.add_argument(\"--critic_lora_alpha\", type=int, default=16)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"epoch\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=None)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef create_datasets(args, tokenizer_padding_from_left, ppo_ptx_enabled, tokenizer_padding_from_right):\n    train_dataset = RLHFDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                tokenizer_padding_from_left)\n    iters_prompt = len(train_dataset) // args.train_batch_size\n\n    if ppo_ptx_enabled:\n        pretrain_dataset = SFTDataset(args, os.path.join(args.data_dir, args.pretrain_filename),\n                                      tokenizer_padding_from_right, concat_samples=False)\n        iters_pretrain = len(pretrain_dataset) // args.train_batch_size\n    else:\n        pretrain_dataset = None\n        iters_pretrain = np.inf\n\n    num_update_steps_per_epoch = min(iters_prompt, iters_pretrain) * \\\n                                 (args.train_batch_size / args.ppo_train_batch_size) * \\\n                                 args.ppo_epochs / args.gradient_accumulation_steps\n    num_total_iters = int(args.num_epochs * 
num_update_steps_per_epoch)\n\n    return train_dataset, pretrain_dataset, num_total_iters\n\n\ndef create_dataloader(args, train_dataset, pretrain_dataset=None):\n\n    # DataLoaders creation:\n    # data_collator = DataCollatorRLHF(args.max_length, pad_token_id)\n    if args.local_rank == -1:\n        prompt_train_sampler = RandomSampler(train_dataset)\n        if pretrain_dataset is not None:\n            pretrain_sampler = RandomSampler(pretrain_dataset)\n    else:\n        prompt_train_sampler = DistributedSampler(train_dataset)\n        if pretrain_dataset is not None:\n            pretrain_sampler = DistributedSampler(pretrain_dataset)\n\n    # prompt_train_dataloader is only used for generate_experience() where batch size is fixed at 1\n    prompt_train_dataloader = DataLoader(\n        train_dataset,\n        # collate_fn=data_collator,\n        sampler=prompt_train_sampler,\n        batch_size=1)\n        # batch_size=args.train_batch_size)\n    if pretrain_dataset is not None:\n        pretrain_dataloader = DataLoader(\n            pretrain_dataset,\n            # collate_fn=default_data_collator,\n            sampler=pretrain_sampler,\n            batch_size=args.train_batch_size)\n    else:\n        pretrain_dataloader = [None] * len(\n            prompt_train_dataloader)\n\n    # num_update_steps_per_epoch = min(len(prompt_train_dataloader), len(pretrain_dataloader)) * \\\n    #                              (args.train_batch_size / args.ppo_train_batch_size) * \\\n    #                              args.ppo_epochs / args.gradient_accumulation_steps\n    # num_total_iters = int(args.num_epochs * num_update_steps_per_epoch)\n\n    return prompt_train_dataloader, pretrain_dataloader#, num_total_iters\n\n\ndef main():\n    args = get_parser()\n\n    assert not args.enable_hybrid_engine, \"DeepSpeed currently does not support Pangu-based or GLM-based models in hybrid engine\"\n\n    if args.local_rank == -1:\n        device = torch.device(\"cuda\")\n    else:\n        torch.cuda.set_device(args.local_rank)\n        device = torch.device(\"cuda\", args.local_rank)\n        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs\n        # deepspeed.init_distributed()\n\n    # args.global_rank = torch.distributed.get_rank()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n    # torch.distributed.barrier()\n\n    # Set PPO-ptx\n    ppo_ptx_enabled = args.pretrain_filename is not None\n    if ppo_ptx_enabled:\n        args.gradient_accumulation_steps_actor = args.gradient_accumulation_steps * 2\n    else:\n        args.gradient_accumulation_steps_actor = args.gradient_accumulation_steps\n    # n_gpus = torch.distributed.get_world_size()\n    n_gpus = torch.cuda.device_count()\n    args.global_train_batch_size_actor = args.ppo_train_batch_size * args.gradient_accumulation_steps_actor * n_gpus\n    args.global_train_batch_size_critic = args.ppo_train_batch_size * args.gradient_accumulation_steps * n_gpus\n\n    # load tokenizer\n    tokenizer_padding_from_right = AutoTokenizer.from_pretrained(args.tokenizer_path, use_cache=False, trust_remote_code=True)\n    if \"chatglm2\" in args.tokenizer_path:\n        tokenizer_padding_from_right.eop_token_id = tokenizer_padding_from_right.get_command(\"eop\")\n    tokenizer_padding_from_left = copy.deepcopy(tokenizer_padding_from_right)\n    # tokenizer.pad_token = tokenizer.eos_token\n    tokenizer_padding_from_left.padding_side = \"left\" # PS: padding side slightly affects the output of SFT 
generation and reward model result\n    args.max_prompt_length = args.max_length - args.max_gen_length\n\n    print_gpu_utilization(\"before create deepspeed rlhf engine\", args.local_rank)\n    print_gpu_utilization_torch(\"before create deepspeed rlhf engine\", args.local_rank)\n\n    if args.do_train:\n        # load data and create dataset\n        prompt_train_dataset, pretrain_dataset, num_total_iters = create_datasets(args, tokenizer_padding_from_left,\n                                                                                  ppo_ptx_enabled, tokenizer_padding_from_right)\n        args.warmup_steps = int(num_total_iters * args.warmup_ratio)\n\n        # load rlhf engine\n        rlhf_engine = DeepSpeedRLHFEngine(\n            actor_model_name_or_path=args.actor_model_path,\n            critic_model_name_or_path=args.critic_model_path,\n            tokenizer=tokenizer_padding_from_right,\n            num_total_iters=num_total_iters,\n            args=args)\n\n        # create dataloader [need to be called after rlhf engine initialization because\n        # DistributedSampler can only be called after deepspeed.initialize() is called]\n        prompt_train_dataloader, pretrain_dataloader = create_dataloader(args, prompt_train_dataset,\n                                                                         pretrain_dataset)\n\n        # create deepspeed ppo trainer\n        print_gpu_utilization(\"after create deepspeed rlhf engine\", args.local_rank)\n        print_gpu_utilization_torch(\"after create deepspeed rlhf engine\", args.local_rank)\n        ppo_trainer = DeepSpeedPPOPTXTrainer if ppo_ptx_enabled else DeepSpeedPPOTrainer\n        trainer = ppo_trainer(rlhf_engine, args)\n        print_gpu_utilization(\"after create deepspeed trainer\", args.local_rank)\n        print_gpu_utilization_torch(\"after create deepspeed trainer\", args.local_rank)\n\n        # create ppo experience dataset\n        exp_mini_dataset = PPODataset(args.ppo_batch_numbers,\n                                      args.ppo_train_batch_size)\n        pretrain_mini_dataset = PPODataset(args.ppo_batch_numbers,\n                                           args.ppo_train_batch_size)\n\n        print_rank_0(\"Start training\")\n        for epoch in range(args.num_epochs):\n            print_rank_0(f\"Beginning of Epoch {epoch+1}/{args.num_epochs}, \"\n                         f\"Total Generation Batches {min(len(prompt_train_dataloader), len(pretrain_dataloader))}\")\n            prompt_iter = iter(prompt_train_dataloader)\n            pretrain_iter = iter(pretrain_dataloader)\n            step = 0\n            # for step, (batch_prompt, batch_pretrain) in enumerate(zip(prompt_train_dataloader, pretrain_dataloader)):\n            while True:\n                # generate sequence: generate only one sequence at a time, aggregate to form a batch\n                answer_start_indices = []\n                output_sequences = dict()\n                for _ in range(args.train_batch_size):\n                    try:\n                        batch_prompt = next(prompt_iter)\n                        batch_prompt = {k: v.to(device) for k, v in batch_prompt.items()}\n                        outputs, prompt_length = trainer.generate_sequence(batch_prompt)\n                        answer_start_indices.append(prompt_length-1)\n                        for key, val in outputs.items():\n                            if key not in output_sequences:\n                                output_sequences[key] = []\n                   
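# NOTE (added): responses are generated one prompt at a time (the prompt dataloader\n                    # uses batch_size=1) and stacked below into a single batch for experience-making;\n                    # this assumes generate_sequence() pads all outputs to the same length.\n                   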
         output_sequences[key].append(val[0])\n                    except StopIteration:\n                        break\n                if len(output_sequences) > 0:\n                    output_sequences = {key: torch.stack(val).to(device) for key, val in output_sequences.items()}\n                    output_experiences = trainer.generate_experience(output_sequences, answer_start_indices)\n                    exp_dataset = exp_mini_dataset.add(output_experiences)\n                else:\n                    # prompt dataloader is exhausted, so this epoch is finished\n                    break\n\n                try:\n                    batch_pretrain = next(pretrain_iter)\n                    if batch_pretrain is not None:\n                        batch_pretrain = {k: v.to(device) for k, v in batch_pretrain.items()}\n                        pretrain_dataset = pretrain_mini_dataset.add(batch_pretrain)\n                    else:\n                        pretrain_dataset = pretrain_mini_dataset.add([[None] * args.train_batch_size])\n                except StopIteration:\n                    pass\n\n                if exp_dataset is not None:\n                    inner_iter = 0\n                    # keep running sums separate from the per-step losses returned by the trainer\n                    actor_loss_sum, critic_loss_sum, pretrain_loss_sum = 0.0, 0.0, 0.0\n                    average_reward = 0\n\n                    if args.actor_gradient_checkpointing:\n                        rlhf_engine.actor.gradient_checkpointing_enable()\n\n                    for ppo_ep in range(args.ppo_epochs):\n                        for i, (exp_data, pretrain_data) in enumerate(\n                                zip(exp_dataset, pretrain_dataset)):\n                            actor_loss, critic_loss = trainer.train_rlhf(exp_data)\n                            actor_loss_sum += actor_loss.item()\n                            critic_loss_sum += critic_loss.item()\n                            average_reward += exp_data[\"rewards\"].mean()\n\n                            if ppo_ptx_enabled:\n                                pretrain_loss = trainer.train_unsupervised(pretrain_data, args.pretrain_coef)\n                                pretrain_loss_sum += pretrain_loss.item()\n\n                            inner_iter += 1\n                            if args.enable_ema:\n                                moving_average(rlhf_engine.actor,\n                                               rlhf_engine.actor_ema,\n                                               zero_stage=args.actor_zero_stage)\n\n                        random.shuffle(exp_dataset)\n                        random.shuffle(pretrain_dataset)\n\n                    print_rank_0(f'epoch: {epoch}, step: {step}, ppo_ep: {ppo_ep+1}, act_loss: {actor_loss_sum/inner_iter}, '\n                                 f'cri_loss: {critic_loss_sum/inner_iter}, pretrain_loss: {pretrain_loss_sum/inner_iter}')\n                    average_reward = get_all_reduce_mean(average_reward).item()\n                    print_rank_0(f\"average reward score: {average_reward/inner_iter}\")\n\n                if args.actor_gradient_checkpointing:\n                    rlhf_engine.actor.gradient_checkpointing_disable()\n\n                step += 1\n\n        print_rank_0('saving model ...')\n\n        # if args.actor_lora_rank > 0:\n        #     rlhf_engine.actor = convert_lora_to_linear_layer(rlhf_engine.actor)\n        #     if args.enable_ema:\n        #         rlhf_engine.actor_ema = convert_lora_to_linear_layer(rlhf_engine.actor_ema)\n        # if args.critic_lora_rank > 0:\n        #     rlhf_engine.critic = convert_lora_to_linear_layer(rlhf_engine.critic)\n\n        if args.local_rank == 0:\n   
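# NOTE (added): rank 0 writes HuggingFace-format weights; under ZeRO stage 3 the\n   # partitioned parameters are additionally gathered and saved below.\n   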
         save_hf_format(rlhf_engine.actor, tokenizer_padding_from_right, args, sub_folder='actor')\n            save_hf_format(rlhf_engine.critic, tokenizer_padding_from_right, args, sub_folder='critic')\n            if args.enable_ema:\n                save_hf_format(rlhf_engine.actor_ema, tokenizer_padding_from_right, args, sub_folder='actor_ema')\n\n        if args.actor_zero_stage == 3:\n            save_zero_three_model(rlhf_engine.actor, global_rank=args.local_rank,\n                                  save_dir=os.path.join(args.output_dir, 'actor'),\n                                  zero_stage=args.actor_zero_stage)\n            if args.enable_ema:\n                save_zero_three_model(rlhf_engine.actor_ema, global_rank=args.local_rank,\n                                      save_dir=os.path.join(args.output_dir, 'actor_ema'),\n                                      zero_stage=args.actor_zero_stage)\n        if args.critic_zero_stage == 3:\n            save_zero_three_model(rlhf_engine.critic, global_rank=args.local_rank,\n                                  save_dir=os.path.join(args.output_dir, 'critic'),\n                                  zero_stage=args.critic_zero_stage)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/train_rlhf_trlx.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\n\nimport os\nimport argparse\nimport torch\nimport glob\n\nfrom typing import Callable, Dict, Iterable, List, Optional, Tuple, Any\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n\nfrom src.utils import logger, RESOURCE_PATH\nfrom src.utils.config import TRLConfig, default_ilql_config, default_ppo_config, default_sft_config\nfrom src.models.reward import RewardModel\nfrom src.utils.file_utils import set_seed\nfrom src.data.data import RLHFDataset\nfrom src.utils.loading import get_pipeline, get_trainer\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--tokenizer_path\", type=str, required=True)\n    parser.add_argument(\"--sft_model_path\", type=str, required=True)\n    parser.add_argument(\"--reward_model_path\", type=str, required=True)\n    parser.add_argument(\"--reward_checkpoint\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"linear\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n    parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"epoch\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=None)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--ppo_config\", type=str, default=None)\n    parser.add_argument(\"--num_layers_unfrozen\", type=int, default=-1, help=\"Number of layers to unfreeze for fine-tuning\")\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n   
 parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"epoch\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=None)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef train(model_path: Optional[str] = None,\n          reward_fn: Optional[Callable[[List[Any], Any], torch.Tensor]] = None,\n          dataset: Optional[Iterable[Tuple[str, float]]] = None,\n          samples: Optional[List[str]] = None,\n          rewards: Optional[List[float]] = None,\n          prompts: Optional[List[str]] = None,\n          eval_prompts: Optional[List[str]] = None,\n          metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,\n          config: Optional[TRLConfig] = None,\n          stop_sequences: Optional[List[str]] = [],):\n    if config is None:\n        logger.warn(\n            \"Passing the `config` argument implicitly is depreciated, use or\"\n            \"adapt some from default configs instead\"\n        )\n        if reward_fn:\n            config = default_ppo_config()\n        elif rewards:\n            config = default_ilql_config()\n        else:\n            config = default_sft_config()\n\n    set_seed(config.train.seed)\n\n    if dataset:\n        logger.warn(\"the `dataset` argument is being depreciated, split it into `samples` and `rewards` instead\")\n        samples, rewards = dataset\n\n    if model_path:\n        config.model.model_path = model_path\n\n    trainer = get_trainer(config.train.trainer)(\n        config=config,\n        reward_fn=reward_fn,\n        metric_fn=metric_fn,\n        stop_sequences=stop_sequences,\n        **config.train.trainer_kwargs,\n    )\n\n    batch_size = config.train.batch_size * int(os.environ.get(\"WORLD_SIZE\", 1))\n    max_prompt_length = config.train.seq_length - config.method.gen_kwargs[\"max_new_tokens\"]\n\n    # Online training against a reward function (e.g. PPO)\n    if reward_fn:\n        prompts = prompts or [trainer.tokenizer.bos_token] * batch_size\n\n        if eval_prompts is None:\n            eval_prompts = prompts[:batch_size]\n\n        pipeline = get_pipeline(config.train.pipeline)(prompts, config, trainer.tokenizer)\n        trainer.add_prompt_pipeline(pipeline)\n\n        if eval_prompts is None:\n            eval_prompts = prompts[:batch_size]\n\n        trainer.make_experience(config.method.num_rollouts)\n\n    # Offline training from the collected samples (e.g. 
SFT, ILQL)\n    elif samples:\n        if rewards:\n            if len(samples) != len(rewards):\n                raise ValueError(f\"Number of samples {len(samples)} should match the number of rewards {len(rewards)}\")\n\n        if eval_prompts is None:\n            eval_prompts = [trainer.tokenizer.bos_token] * batch_size\n\n        if rewards:\n            trainer.make_experience(samples, rewards, config.train.seq_length)\n        else:\n            trainer.store = get_pipeline(config.train.pipeline)(samples, config, trainer.tokenizer)\n\n    else:\n        raise ValueError(\"Either `samples` or `reward_fn` should be given for training\")\n\n    eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, config, trainer.tokenizer)\n    trainer.add_eval_pipeline(eval_pipeline)\n\n    trainer.learn()\n\n\ndef main():\n    args = get_parser()\n    logger.info(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_cache=False, trust_remote_code=True)\n    # tokenizer.pad_token = tokenizer.eos_token\n    tokenizer.padding_side = \"left\" # PS: padding side does affect output of reward model\n\n    # load reward model\n    if \"pangu\" in args.reward_model_path:\n        model = AutoModelForCausalLM.from_pretrained(args.reward_model_path, use_cache=False, trust_remote_code=True)\n        model.resize_token_embeddings(tokenizer.vocab_size)\n        # model.config.end_token_id = tokenizer.eos_token_id\n        # model.config.pad_token_id = tokenizer.pad_token_id\n        # model.config.bos_token_id = tokenizer.bos_token_id\n        # model.config.eos_token_id = tokenizer.eos_token_id\n        model.config.lora_rank = args.lora_rank\n        model.config.lora_alpha = args.lora_alpha\n        model.config.lora_train_bias = args.lora_train_bias\n        model.config.target_modules = \"q_proj,k_proj,v_proj\"\n        model.config.task_type = \"CAUSAL_LM\"\n        # Initialize the reward model from the (supervised) fine-tuned SFT model\n        reward_model = RewardModel(model.config, model.transformer, tokenizer)\n        # reward_model = RewardModelWithLoRA(model.config, model.transformer, tokenizer)\n    elif \"chatglm\" in args.reward_model_path:\n        model = AutoModelForSeq2SeqLM.from_pretrained(args.reward_model_path, trust_remote_code=True).half()\n        model.config.lora_rank = args.lora_rank\n        model.config.lora_alpha = args.lora_alpha\n        model.config.lora_train_bias = args.lora_train_bias\n        model.config.target_modules = \"query_key_value\"\n        model.config.task_type = \"SEQ_2_SEQ_LM\"\n        # Initialize the reward model from the (supervised) fine-tuned SFT model\n        reward_model = RewardModel(model.config, model, tokenizer)\n        # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\n    elif \"glm\" in args.reward_model_path:\n        model = AutoModelForSeq2SeqLM.from_pretrained(args.reward_model_path, trust_remote_code=True)\n        model.config.lora_rank = args.lora_rank\n        model.config.lora_alpha = args.lora_alpha\n        model.config.lora_train_bias = args.lora_train_bias\n        model.config.target_modules = \"query_key_value\"\n        model.config.task_type = \"SEQ_2_SEQ_LM\"\n        # Initialize the reward model from the (supervised) fine-tuned SFT model\n        reward_model = RewardModel(model.config, model.glm, tokenizer)\n        # reward_model = RewardModelWithLoRA(model.config, model.glm, tokenizer)\n    
else:\n        raise ValueError(f\"Unsupported model name: {args.reward_model_path}\")\n    assert model.config.pad_token_id == tokenizer.pad_token_id\n\n    if args.reward_checkpoint is not None:\n        checkpoints = glob.glob(args.reward_checkpoint.replace(\"star\", \"*\"))\n        st = dict()\n        for checkpoint in checkpoints:\n            st.update(torch.load(checkpoint, map_location=\"cpu\"))\n        res = reward_model.load_state_dict(st, strict=False)\n\n    device = torch.device(f\"cuda:{args.local_rank}\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    # reward_model.half()\n    reward_model.eval()\n    reward_model.to(device)\n    logger.info(f\"Finished loading reward model from {args.reward_checkpoint}\")\n\n    def reward_fn(samples, **kwargs):\n        scores_list = []\n        for i in range(0, len(samples), ppo_config.train.batch_size):\n            input_ids_list = []\n            attention_mask_list = []\n            position_ids_list = []\n            for sample in samples[i: i + ppo_config.train.batch_size]:\n                prompt, pred = sample.split(tokenizer.sep_token, maxsplit=1)\n                logger.debug(f\"prompt: {prompt}, pred: {pred}\")\n                if \"pangu\" in ppo_config.model.model_path:\n                    encodings_dict = tokenizer(prompt, pred, max_length=ppo_config.train.seq_length,\n                                               truncation=\"longest_first\", padding=\"max_length\", return_tensors=\"pt\",\n                                               return_token_type_ids=False)\n                    input_ids_list.append(encodings_dict[\"input_ids\"][0])\n                    attention_mask_list.append(encodings_dict[\"attention_mask\"][0])\n                elif \"chatglm\" in ppo_config.model.model_path:\n                    encoded_dict = tokenizer(prompt, pred, max_length=ppo_config.train.seq_length, return_tensors=\"pt\",\n                                             truncation=\"longest_first\", padding=\"max_length\")\n                    input_ids_list.append(encoded_dict[\"input_ids\"][0])\n                elif \"glm\" in ppo_config.model.model_path:\n                    # TODO: to be modified and tested for glm\n                    encoded_prompt = tokenizer(prompt, tokenizer.mask_token)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    label_length = len(tokenizer.tokenize(pred))\n                    if prompt_length + label_length > ppo_config.train.seq_length:\n                        num_tokens_to_remove = prompt_length + label_length - ppo_config.train.seq_length\n                        for _ in range(num_tokens_to_remove):\n                            if prompt_length > label_length:\n                                prompt_length -= 1\n                            else:\n                                label_length -= 1\n                    else:\n                        label_length = ppo_config.train.seq_length - prompt_length\n                    assert prompt_length > 0\n                    assert label_length > 0\n                    assert prompt_length + label_length <= ppo_config.train.seq_length\n                    encoded_dict = tokenizer(prompt, tokenizer.mask_token,\n                                             max_length=prompt_length, truncation=\"only_first\",\n                                             return_tensors=\"pt\", return_attention_mask=True,\n                                             return_token_type_ids=False)\n                   
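# NOTE (added): GLM encodes in two steps: the prompt is tokenized with a [MASK]\n                    # placeholder first, and build_inputs_for_generation() then appends the target\n                    # span; the bookkeeping above keeps prompt + target within seq_length.\n                   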
 encoded_dict = tokenizer.build_inputs_for_generation(encoded_dict, targets=pred,\n                                                                         max_gen_length=label_length, padding=True)\n                    input_ids_list.append(encoded_dict[\"input_ids\"][0])\n                    attention_mask_list.append(encoded_dict[\"attention_mask\"][0])\n                    position_ids_list.append(encoded_dict[\"position_ids\"][0])\n                else:\n                    raise ValueError(f\"Unsupported model type: {ppo_config.model.model_path}\")\n            # encodings_dict = tokenizer(\n            #     sub_samples,\n            #     max_length=ppo_config.train.seq_length,\n            #     truncation=\"longest_first\",\n            #     padding=\"max_length\",\n            #     return_tensors=\"pt\",\n            # )\n            input_ids = torch.stack(input_ids_list, dim=0).to(device)\n            attention_mask = torch.stack(attention_mask_list, dim=0).to(device) if len(attention_mask_list) > 0 else None\n            position_ids = torch.stack(position_ids_list, dim=0).to(device) if len(position_ids_list) > 0 else None\n            with torch.no_grad():\n                sub_scores = reward_model(input_ids, attention_mask, position_ids)\n            scores_list.append(sub_scores[\"chosen_reward\"])\n\n        scores = torch.cat(scores_list, dim=0)\n\n        return scores\n\n    # load ppo config\n    ppo_config = TRLConfig.load_yaml(os.path.join(RESOURCE_PATH, \"config\", \"ppo_model\", args.ppo_config))\n    ppo_config.train.epochs = args.num_epochs\n    ppo_config.train.seq_length = args.max_length\n    ppo_config.train.batch_size = args.train_batch_size\n    ppo_config.train.checkpoint_dir = args.output_dir\n    ppo_config.train.checkpoint_interval = args.save_steps\n    ppo_config.train.eval_interval = args.eval_steps\n    ppo_config.model.num_layers_unfrozen = args.num_layers_unfrozen\n    ppo_config.model.model_path = args.sft_model_path\n    ppo_config.tokenizer.tokenizer_path = args.tokenizer_path\n    ppo_config.optimizer.kwargs['lr'] = args.learning_rate\n    ppo_config.optimizer.kwargs['weight_decay'] = args.weight_decay\n    ppo_config.method.chunk_size = args.eval_batch_size\n    ppo_config.train.lora_rank = args.lora_rank\n    ppo_config.train.lora_alpha = args.lora_alpha\n    ppo_config.train.lora_train_bias = args.lora_train_bias\n    logger.info(f\"PPO config: {ppo_config}\")\n\n    # load dataset\n    if args.do_train:\n        train_dataset = RLHFDataset.load_dataset(os.path.join(args.data_dir, args.train_filename))\n    else:\n        train_dataset = None\n    if args.do_eval:\n        dev_dataset = RLHFDataset.load_dataset(os.path.join(args.data_dir, args.eval_filename))\n    else:\n        dev_dataset = None\n\n    if args.do_train:\n        train(model_path=args.sft_model_path, reward_fn=reward_fn, prompts=train_dataset,\n              eval_prompts=dev_dataset, config=ppo_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/train_sft.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport evaluate\nimport torch\n\nfrom tqdm import tqdm\nfrom transformers import (\n    Trainer,\n    TrainingArguments,\n    default_data_collator,\n)\n\nfrom src.utils import RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.data.data import SFTDataset, chatglm2_encode, chatglm3_encode\nfrom src.utils.file_utils import set_seed, print_rank_0\n# from src.models import convert_to_lora_recursively\n\n\n# Create a preprocessing function to extract out the proper logits from the model output\ndef preprocess_logits_for_metrics(logits, labels):\n    if isinstance(logits, tuple):\n        logits = logits[0]\n\n    return logits.argmax(dim=-1)\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--tokenizer_path\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    parser.add_argument(\"--max_length_generation\", type=int, default=None)\n    parser.add_argument(\"--bits\", type=int, default=32,\n                        help=\"bits used to load model, including: 32, 16, 8, 4\")\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--concat_samples\", action=\"store_true\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\",\n                        help=\"transformers.trainer_utils.SchedulerType, including:\"\n                             \"linear, cosine, cosine_with_restarts, polynomial, constant,\"\n                             \"constant_with_warmup\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--warmup_ratio\", type=int, default=0.1)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- 
`\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every `save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=None)\n    parser.add_argument(\"--top_p\", type=float, default=None)\n    parser.add_argument(\"--temperature\", type=float, default=None)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    # load tokenizer and model\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args)\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, model, strict=False)\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = SFTDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                   tokenizer, concat_samples=args.concat_samples)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        dev_dataset = SFTDataset(args, os.path.join(args.data_dir, args.eval_filename),\n                                 tokenizer, concat_samples=False)\n    else:\n        dev_dataset = None\n    if args.do_pred:\n        test_dataset = SFTDataset(args, os.path.join(args.data_dir, args.test_filename),\n                                  tokenizer, concat_samples=False)\n    else:\n        test_dataset = None\n\n    if args.do_train:\n        if torch.cuda.is_available():\n            bf16 = torch.cuda.get_device_capability()[0] >= 8\n            fp16 = not bf16\n        else:\n          
  fp16 = False\n            bf16 = False\n        # training arguments\n        deepspeed_config = os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config) if args.deepspeed_config is not None else None\n        training_args = TrainingArguments(\n            output_dir=args.output_dir,\n            no_cuda=not torch.cuda.is_available(),\n            seed=args.seed,\n            data_seed=args.seed,\n            local_rank=args.local_rank,\n            do_train=args.do_train,\n            num_train_epochs=args.num_epochs,\n            learning_rate=args.learning_rate,\n            lr_scheduler_type=args.lr_scheduler_type,\n            per_device_train_batch_size=args.train_batch_size,\n            gradient_accumulation_steps=args.gradient_accumulation_steps,\n            warmup_ratio=args.warmup_ratio,\n            weight_decay=args.weight_decay,\n            half_precision_backend=\"auto\",\n            fp16=fp16,\n            bf16=bf16,\n            optim=\"paged_adamw_8bit\",\n            # adam_beta1=0.9,\n            # adam_beta2=0.95,\n            save_strategy=args.save_strategy,\n            save_steps=args.save_steps,\n            save_total_limit=args.save_total_limit,\n            metric_for_best_model=args.metric_for_best_model,\n            greater_is_better=True,\n            logging_steps=args.logging_steps,\n            report_to=[\"tensorboard\"],\n            deepspeed=deepspeed_config,\n            gradient_checkpointing=args.gradient_checkpointing,\n            do_eval=args.do_eval,\n            evaluation_strategy=args.evaluation_strategy,\n            eval_steps=args.eval_steps,\n            eval_accumulation_steps=args.eval_accumulation_steps,\n            per_device_eval_batch_size=args.eval_batch_size,\n            # do_predict=args.do_pred,\n            # use_legacy_prediction_loop=args.do_pred,\n        )\n        print_rank_0(f\"Training Arguments: {training_args}\")\n\n        # Set up the metric\n        rouge = evaluate.load(\"rouge\")\n\n        def compute_metrics(eval_preds):\n            labels_ids = eval_preds.label_ids\n            pred_ids = eval_preds.predictions\n            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n            label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n            result = rouge.compute(predictions=pred_str, references=label_str)\n\n            return result\n\n        # Prepare the trainer and start training\n        trainer = Trainer(\n            model=model,\n            args=training_args,\n            train_dataset=train_dataset,\n            eval_dataset=dev_dataset,\n            compute_metrics=compute_metrics,\n            data_collator=default_data_collator,\n            preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n        )\n        # model.config.use_cache = False  # silence the warnings. 
Please re-enable for inference!\n        trainer.train()\n        trainer.save_model(args.output_dir)\n\n    elif args.do_eval:\n        # res = trainer.evaluate(eval_dataset=dev_dataset)\n        # print_rank_0(res)\n        pass\n\n    if args.do_pred:\n        model.eval()\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() and args.device_map is not None else \"cpu\"\n        tokenizer.padding_side = \"left\"\n        with open(os.path.join(args.output_dir, args.output_filename), \"w\", encoding=\"utf-8\") as w:\n            w.write(\"\\t\".join([\"prompt\"]+[f\"model_answer_{i}\" for i in range(args.num_return_sequences)])+\"\\n\")\n            for test_data in tqdm(test_dataset.post_list, desc=\"Prediction\"):\n                prompt = test_data['prompt']\n                prefix = test_data.get('prefix', \"\")\n                system = test_data.get('system', \"\")\n                if \"chatglm3\" in args.model_name_or_path.lower():\n                    _, _, prompt_ids = chatglm3_encode(tokenizer, prompt, None, system, args.max_length)\n                    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(prompt_ids)\n                elif \"chatglm2\" in args.model_name_or_path.lower():\n                    _, _, prompt_ids = chatglm2_encode(tokenizer, prompt, None, system, args.max_length)\n                    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n                    outputs = model.generate(input_ids=input_ids,\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(prompt_ids)\n                elif \"chatglm\" in args.model_name_or_path.lower():\n                    prompt = \"\\n\\n\".join((system, prompt))\n                    encoded_prompt = tokenizer(prompt)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    inputs = tokenizer(prompt,\n                                       max_length=min(prompt_length, args.max_length),\n                                       truncation=\"only_first\",\n                                       return_tensors=\"pt\")\n                    # max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n                    
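# NOTE (added): in this chatglm branch the prompt is passed to generate() directly;\n                    # only the glm branch below needs build_inputs_for_generation().\n                    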
# inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n                    #                                                max_gen_length=max_gen_length, padding=True)\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(inputs['input_ids'][0])\n                elif \"glm\" in args.model_name_or_path.lower():\n                    encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n                                             max_length=min(prompt_length, args.max_length),\n                                             truncation=\"only_first\",\n                                             return_tensors=\"pt\",\n                                             return_token_type_ids=False)\n                    max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n                    inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n                                                                   max_gen_length=max_gen_length, padding=True)\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=min(args.max_length_generation, max_gen_length),\n                                             eos_token_id=eos_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(inputs['input_ids'][0])\n                else:\n                    inputs = tokenizer(prompt, tokenizer.sep_token + prefix, max_length=args.max_length,\n                                       truncation=\"only_first\", add_special_tokens=False,\n                                       return_tensors=\"pt\", return_token_type_ids=False)\n                    # inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=args.max_length_generation,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             
num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                    prompt_length = len(inputs['input_ids'][0])\n                results = tokenizer.batch_decode([output[prompt_length:] for output in outputs], skip_special_tokens=True)\n                w.write(\"\\t\".join([prompt]+results)+\"\\n\")\n\n    \nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/train_sft_wo_trainer.py",
    "content": "\nimport sys\nsys.path.insert(0, \"/root/autodl-tmp/Code/RLHF\")\nsys.path.insert(0, \"/mnt/sfevol775196/sunzeye273/Code/chatgpt\")\n# sys.path.insert(0, \"/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt\")\nsys.path.insert(0, \"/mnt/pa002-28359-vol543625-private/Code/chatgpt\")\nimport os\nimport argparse\nimport evaluate\nimport torch\nimport json\nimport deepspeed\n\nfrom tqdm import tqdm\nfrom torch.utils.data import RandomSampler, DistributedSampler, DataLoader\nfrom transformers.deepspeed import HfDeepSpeedConfig\n\n\nfrom src.utils import logger, RESOURCE_PATH, load_tokenizer_and_model, load_checkpoint\nfrom src.data.data import SFTDataset\nfrom src.utils.file_utils import set_seed, print_rank_0\nfrom src.utils.modeling_utils import rotate_checkpoints, save_zero_three_model\n# from src.models import convert_to_lora_recursively\n\n\n# Create a preprocessing function to extract out the proper logits from the model output\ndef preprocess_logits_for_metrics(logits, labels):\n    if isinstance(logits, tuple):\n        logits = logits[0]\n\n    return logits.argmax(dim=-1)\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument(\"--data_dir\", type=str, required=True)\n    parser.add_argument(\"--output_dir\", type=str, required=True)\n    parser.add_argument(\"--model_name_or_path\", type=str, required=True)\n\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--local_rank\", type=int, default=0)\n    parser.add_argument(\"--max_length\", type=int, default=1024)\n    parser.add_argument(\"--bits\", type=int, default=16)\n    parser.add_argument(\"--device_map\", type=str, default=None, help=\"device map to allocate model,\"\n                                                                     \"[None] means cpu\"\n                                                                     \"[0, 1, 2, ...], number means single-card\"\n                                                                     \"[auto, balanced, balanced_low_0] means multi-card\")\n    parser.add_argument(\"--low_cpu_mem_usage\", action=\"store_true\", help=\"whether to enable low cpu memory usage\"\n                                                                         \"when loading model\")\n    parser.add_argument(\"--max_length_generation\", type=int, default=None)\n    # train\n    parser.add_argument(\"--do_train\", action=\"store_true\")\n    parser.add_argument(\"--train_filename\", type=str, default=None)\n    parser.add_argument(\"--concat_samples\", action=\"store_true\")\n    parser.add_argument(\"--num_epochs\", type=int, default=1)\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-6)\n    parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"OneCycle\",\n                        help=\"deepspeed scheduler types, including:\"\n                             \"LRRangeTest, OneCycle, WarmupLR, WarmupDecayLR\")\n    parser.add_argument(\"--train_batch_size\", type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.1)\n    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n    parser.add_argument(\"--logging_steps\", type=int, default=100)\n    parser.add_argument(\"--save_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No save is done during training.'\n                             '- `\"epoch\"`: Save is done at the end of each epoch.'\n                             '- `\"steps\"`: Save is done every 
`save_steps`.')\n    parser.add_argument(\"--save_steps\", type=int, default=1000)\n    parser.add_argument(\"--save_total_limit\", type=int, default=2)\n    parser.add_argument(\"--metric_for_best_model\", type=str, default=None)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=8)\n    parser.add_argument(\"--gradient_checkpointing\", action=\"store_true\",\n                        help=\"If True, use gradient checkpointing to save memory at the expense of slower backward pass.\")\n    parser.add_argument(\"--deepspeed_config\", type=str, default=None)\n    parser.add_argument(\"--lora_rank\", type=int, default=0)\n    parser.add_argument(\"--lora_alpha\", type=int, default=1)\n    parser.add_argument(\"--lora_train_bias\", type=str, default=\"none\")\n    # eval\n    parser.add_argument(\"--do_eval\", action=\"store_true\")\n    parser.add_argument(\"--eval_filename\", type=str, default=None)\n    parser.add_argument(\"--eval_batch_size\", type=int, default=4)\n    parser.add_argument(\"--evaluation_strategy\", type=str, default=\"steps\",\n                        help='- `\"no\"`: No evaluation is done during training.'\n                             '- `\"steps\"`: Evaluation is done (and logged) every `eval_steps`.'\n                             '- `\"epoch\"`: Evaluation is done at the end of each epoch.')\n    parser.add_argument(\"--eval_steps\", type=int, default=100)\n    parser.add_argument(\"--eval_accumulation_steps\", type=int, default=1)\n    # pred\n    parser.add_argument(\"--do_pred\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint\", type=str, default=None)\n    parser.add_argument(\"--test_filename\", type=str, default=None)\n    parser.add_argument(\"--output_filename\", type=str, default=None)\n    parser.add_argument(\"--do_sample\", action=\"store_true\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1)\n    parser.add_argument(\"--top_k\", type=int, default=None)\n    parser.add_argument(\"--top_p\", type=float, default=None)\n    parser.add_argument(\"--temperature\", type=float, default=None)\n\n    args = parser.parse_args()\n    \n    return args\n\n\ndef main():\n    args = get_parser()\n    print_rank_0(f\"Parameters: {args}\")\n\n    set_seed(args.seed)\n\n    torch.cuda.set_device(args.local_rank)\n    device = torch.device(\"cuda\", args.local_rank)\n\n    # create HfDeepSpeedConfig [must be called before instantiating model]\n    if args.deepspeed_config is not None:\n        ds_config_filename = os.path.join(RESOURCE_PATH, \"config\", \"deepspeed\", args.deepspeed_config)\n        ds_config = json.load(open(ds_config_filename, \"r\", encoding=\"utf-8\"))\n        # ds_config[\"steps_per_print\"] = args.logging_steps\n        ds_config[\"train_micro_batch_size_per_gpu\"] = args.train_batch_size\n        ds_config[\"gradient_accumulation_steps\"] = args.gradient_accumulation_steps\n        ds_config[\"gradient_clipping\"] = args.max_grad_norm\n        # TODO: before calling dist init, world size is always 1, therefore ds_config['train_batch_size'] cannot multiply world size\n        ds_config['train_batch_size'] = args.train_batch_size * args.gradient_accumulation_steps #* torch.cuda.device_count()\n        # TODO: assuming hidden_size=4096\n        ds_config[\"zero_optimization\"][\"reduce_bucket_size\"] = 4096 * 4096\n        ds_config[\"zero_optimization\"][\"stage3_prefetch_bucket_size\"] = 0.9 * 4096 * 4096\n      
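# NOTE (added): these ZeRO-3 bucket sizes follow the common DeepSpeed-Chat heuristic of\n        # sizing buffers by hidden_size * hidden_size; adjust them if the model's hidden\n        # size differs from the assumed 4096.\n      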
  ds_config[\"zero_optimization\"][\"stage3_param_persistence_threshold\"] = 10 * 4096\n        if torch.cuda.is_available():\n            bf16 = torch.cuda.get_device_capability()[0] >= 8\n            fp16 = False if bf16 else True\n        else:\n            fp16 = False\n            bf16 = False\n        ds_config[\"fp16\"][\"enabled\"] = fp16\n        ds_config[\"bf16\"][\"enabled\"] = bf16\n        ds_config[\"optimizer\"][\"params\"] = {\n                \"lr\": args.learning_rate,\n                \"betas\": [0.9, 0.95],\n                \"eps\": 1e-8,\n                \"weight_decay\": args.weight_decay\n            }\n        assert ds_config[\"scheduler\"]['type'] == args.lr_scheduler_type\n        ds_config[\"scheduler\"][\"params\"] = {\n            \"cycle_min_lr\": 0,\n            \"cycle_max_lr\": args.learning_rate,\n            \"cycle_first_step_size\": args.warmup_steps\n        }\n        dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive\n\n    # load tokenizer and model\n    tokenizer, model, eos_token_id = load_tokenizer_and_model(args, with_trainer=False)\n\n    if args.checkpoint is not None:\n        load_checkpoint(args, model, strict=False)\n\n    print_rank_0(f\"Finished loading model and tokenizer\")\n\n    # Set up the datasets\n    if args.do_train:\n        train_dataset = SFTDataset(args, os.path.join(args.data_dir, args.train_filename),\n                                   tokenizer, concat_samples=args.concat_samples)\n    else:\n        train_dataset = None\n    if args.do_eval:\n        dev_dataset = SFTDataset(args, os.path.join(args.data_dir, args.eval_filename),\n                                 tokenizer, concat_samples=False)\n    else:\n        dev_dataset = None\n    if args.do_pred:\n        test_dataset = SFTDataset(args, os.path.join(args.data_dir, args.test_filename),\n                                  tokenizer, concat_samples=False)\n    else:\n        test_dataset = None\n\n    if args.do_train:\n        # # Optimizer\n        # AdamOptimizer = DeepSpeedCPUAdam if \"3\" in args.deepspeed_config else FusedAdam\n        # optim_params = get_optimizer_grouped_parameters(\n        #     actor_model, self.args.actor_weight_decay)\n        # optim = AdamOptimizer(optim_params,\n        #                       lr=self.args.actor_learning_rate,\n        #                       betas=(0.9, 0.95))\n        #\n        # # LR Scheduler\n        # lr_scheduler = get_scheduler(\n        #     name=self.args.lr_scheduler_type,\n        #     optimizer=optim,\n        #     num_warmup_steps=self.args.warmup_steps,\n        #     num_training_steps=self.num_total_iters,\n        # )\n\n        # deepspeed initialize\n        ds_config['train_batch_size'] = args.train_batch_size * args.gradient_accumulation_steps * torch.cuda.device_count()\n        model_engine, *_ = deepspeed.initialize(model=model,\n                                                # optimizer=optim,\n                                                # lr_scheduler=lr_scheduler,\n                                                config=ds_config)\n        # create data loader\n        train_sampler = DistributedSampler(train_dataset)\n        train_dataloader = DataLoader(\n            train_dataset,\n            # collate_fn=data_collator,\n            sampler=train_sampler,\n            batch_size=args.train_batch_size)\n\n        # training\n        model_engine.train()\n        if args.gradient_checkpointing:\n            model_engine.module.gradient_checkpointing_enable()\n    
    global_step = 0\n        best_metric = None\n        best_model_checkpoint = None\n        if args.do_eval:\n            assert args.eval_steps <= args.save_steps and args.save_steps % args.eval_steps == 0, \\\n                \"save_steps should be no less than eval_steps and a multiple of eval_steps\"\n            eval_results = eval(global_step)\n            print_rank_0(f\"Epoch-0, Global step-{global_step}, Evaluation result: {eval_results}\")\n            if args.metric_for_best_model is not None:\n                assert args.metric_for_best_model in eval_results, \\\n                    f\"{args.metric_for_best_model} is not a valid metric, \" \\\n                    f\"please choose from the following metrics: {eval_results.keys()}\"\n        for epoch in range(args.num_epochs):\n            print_rank_0(f\"Beginning of Epoch {epoch+1}/{args.num_epochs}\")\n            for step, batch in enumerate(train_dataloader):\n                batch = {k: v.to(device) for k, v in batch.items()}\n                logger.debug(f\"batch keys: {batch.keys()}\")\n                output = model_engine(**batch)\n                model_engine.backward(output.loss)\n                model_engine.step()\n                global_step += 1\n                if global_step % args.logging_steps == 0:\n                    print_rank_0(f\"Epoch-{epoch+1}, Global step-{global_step}, loss: {output.loss}\")\n                if args.do_eval and global_step % args.eval_steps == 0:\n                    eval_results = eval(global_step)\n                    print_rank_0(f\"Epoch-{epoch+1}, Global step-{global_step}, Evaluation result: {eval_results}\")\n                if global_step % args.save_steps == 0:\n                    output_dir = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    if args.do_eval and args.metric_for_best_model is not None:\n                        if (\n                                best_metric is None or\n                                best_model_checkpoint is None or\n                                eval_results[args.metric_for_best_model] > best_metric\n                        ):\n                            best_metric = eval_results[args.metric_for_best_model]\n                            best_model_checkpoint = output_dir\n                    rotate_checkpoints(args.save_total_limit, use_mtime=True, output_dir=args.output_dir,\n                                       best_model_checkpoint=best_model_checkpoint)\n                    # save_zero_three_model(model_engine, args.local_rank,\n                    #                       save_dir=output_dir,\n                    #                       zero_stage=ds_config['zero_optimization']['stage'])\n                    # model_engine.save_16bit_model(output_dir)\n                    model_engine.save_checkpoint(args.output_dir, f\"checkpoint-{global_step}\")\n                    print_rank_0(f\"Finished saving checkpoint @Step-{global_step}\")\n\n        print_rank_0(f\"Finished training! 
epochs: {epoch+1}, steps: {global_step}\")\n\n        output_dir = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n        # save_zero_three_model(model_engine, args.local_rank,\n        #                       save_dir=output_dir,\n        #                       zero_stage=ds_config['zero_optimization']['stage'])\n        # model_engine.save_16bit_model(output_dir)\n        model_engine.save_checkpoint(args.output_dir, f\"checkpoint-{global_step}\")\n        print_rank_0(f\"Finished saving checkpoint @Step-{global_step}\")\n\n    elif args.do_eval:\n        pass\n\n    if args.do_pred:\n        model.eval()\n        device = f\"cuda:{args.local_rank}\" if torch.cuda.is_available() else \"cpu\"\n        model = model.to(device)\n        tokenizer.padding_side = \"left\"\n        with open(os.path.join(args.output_dir, args.output_filename), \"w\", encoding=\"utf-8\") as w:\n            w.write(\"\\t\".join([\"prompt\"]+[f\"model_answer_{i}\" for i in range(args.num_return_sequences)])+\"\\n\")\n            for test_data in tqdm(test_dataset.post_list, desc=\"Prediction\"):\n                prompt = test_data['prompt']\n                prefix = test_data['prefix']\n                # label = dev_data['label']\n                if \"chatglm\" in args.model_name_or_path.lower():\n                    encoded_prompt = tokenizer(prompt)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    inputs = tokenizer(prompt,\n                                       max_length=min(prompt_length, args.max_length),\n                                       truncation=\"only_first\",\n                                       return_tensors=\"pt\")\n                    # max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n                    # inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n                    #                                                max_gen_length=max_gen_length, padding=True)\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=args.max_length_generation,\n                                             eos_token_id=tokenizer.eop_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                elif \"glm\" in args.model_name_or_path.lower():\n                    encoded_prompt = tokenizer(prompt, prefix + tokenizer.mask_token)\n                    prompt_length = len(encoded_prompt['input_ids'])\n                    encoded_dict = tokenizer(prompt, prefix + tokenizer.mask_token,\n                                             max_length=min(prompt_length, args.max_length),\n                                             truncation=\"only_first\",\n                                             return_tensors=\"pt\",\n                                             return_token_type_ids=False)\n                    max_gen_length = args.max_length - encoded_dict['input_ids'].shape[1]\n                    inputs = tokenizer.build_inputs_for_generation(encoded_dict,\n                                       
                            max_gen_length=max_gen_length, padding=True)\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=min(args.max_length_generation, max_gen_length),\n                                             eos_token_id=tokenizer.eop_token_id,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                else:\n                    inputs = tokenizer(prompt, tokenizer.sep_token + prefix, max_length=args.max_length,\n                                       truncation=\"longest_first\", add_special_tokens=False,\n                                       return_tensors=\"pt\", return_token_type_ids=False)\n                    # inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors=\"pt\")\n                    inputs = inputs.to(device)\n                    outputs = model.generate(**inputs,\n                                             max_new_tokens=args.max_length_generation,\n                                             pad_token_id=tokenizer.pad_token_id,\n                                             do_sample=args.do_sample,\n                                             num_return_sequences=args.num_return_sequences,\n                                             top_k=args.top_k,\n                                             top_p=args.top_p,\n                                             temperature=args.temperature)\n                results = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n                w.write(\"\\t\".join([prompt]+[result.split(prefix, maxsplit=1)[1] for result in results])+\"\\n\")\n\n    \nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/utils/__init__.py",
    "content": "\nfrom src.utils.config import *\nfrom src.utils.logger import logger\nfrom src.utils.nlp_utils import *\nfrom src.utils.file_utils import *\nfrom src.utils.loading import *\nfrom src.utils.method_configs import *\nfrom src.utils.modeling_utils import *"
  },
  {
    "path": "src/utils/config.py",
    "content": "\nimport os\nimport yaml\n\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, Optional, Set\n\nfrom src.utils.method_configs import MethodConfig, get_method, PPOConfig, SFTConfig, ILQLConfig\n\n\nROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\nRESOURCE_PATH = os.path.join(ROOT_PATH, \"resources\")\n\n\n# -------- RLHF Config------- #\n\ndef merge(base: Dict, update: Dict, updated: Set) -> Dict:\n    \"Recursively updates a nested dictionary with new values\"\n    for k, v in base.items():\n        if k in update and isinstance(v, dict):\n            base[k] = merge(v, update[k], updated)\n            updated.add(k)\n        elif k in update:\n            base[k] = update[k]\n            updated.add(k)\n\n    return base\n\n\ndef _merge_dicts(base: Dict, update: Dict) -> Dict:\n    \"Merge two dictionaries recursively, returning a new dictionary.\"\n\n    base = deepcopy(base)\n\n    for k, v in update.items():\n        if isinstance(v, dict):\n            base[k] = _merge_dicts(base.get(k, {}), v)\n        else:\n            base[k] = v\n\n    return base\n\n\n@dataclass\nclass ModelConfig:\n    \"\"\"\n    Config for a model.\n\n    :param model_path: Path or name of the model (local or on huggingface hub)\n    :type model_path: str\n\n    :param model_arch_type: Type of model architecture. Either \"causal\" or \"seq2seq\"\n    :type model_arch_type: str\n\n    :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.\n        -1 means all layers are unfrozen.\n    :type num_layers_unfrozen: int\n\n    :param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning.\n        Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set\n        the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`)\n            {\"delta_type\": \"lora\", \"modified_modules\": \"all\", \"lora_r\": 8, \"lora_alpha\": 16, \"lora_dropout\": 0.0}\n        or in YAML format:\n            delta_kwargs:\n                delta_type: lora\n                modified_modules: \"all\"\n                lora_r: 8\n                lora_alpha: 16\n                lora_dropout: 0.0\n        See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig\n    :type delta_kwargs: Optional[Dict[str, Any]]\n    \"\"\"\n\n    model_path: str\n    model_arch_type: str = \"causal\"\n    num_layers_unfrozen: int = -1\n    delta_kwargs: Optional[Dict[str, Any]] = None\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\n@dataclass\nclass TokenizerConfig:\n    \"\"\"\n    Config for a model.\n\n    :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)\n    :type tokenizer_path: str\n\n    :param padding_side: Padding side\n    :type padding_path: str\n\n    :param truncation_side: Truncation side\n    :type truncation_side: str\n    \"\"\"\n\n    tokenizer_path: str\n    padding_side: str = \"left\"\n    truncation_side: str = \"right\"\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\n@dataclass\nclass OptimizerConfig:\n    \"\"\"\n    Config for an optimizer.\n\n    :param name: Name of the optimizer\n    :type name: str\n\n    :param kwargs: Keyword arguments for the optimizer (e.g. 
lr, betas, eps, weight_decay)\n    :type kwargs: Dict[str, Any]\n    \"\"\"\n\n    name: str\n    kwargs: Dict[str, Any] = field(default_factory=dict)\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\n@dataclass\nclass SchedulerConfig:\n    \"\"\"\n    Config for a learning rate scheduler.\n\n    :param name: Name of the scheduler\n    :type name: str\n\n    :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max)\n    :type kwargs: Dict[str, Any]\n    \"\"\"\n\n    name: str\n    kwargs: Dict[str, Any] = field(default_factory=dict)\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\n@dataclass\nclass TrainConfig:\n    \"\"\"\n    Config for train job on model.\n\n    :param total_steps: Total number of training steps\n    :type total_steps: int\n\n    :param seq_length: Number of tokens to use as context (max length for tokenizer)\n    :type seq_length: int\n\n    :param epochs: Total number of passes through data\n    :type epochs: int\n\n    :param batch_size: Batch size for training\n    :type batch_size: int\n\n    :param tracker: Tracker to use for logging. Default: \"wandb\"\n    :type tracker: str\n\n    :param checkpoint_interval: Save model every checkpoint_interval steps.\n        Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir`\n        directory in the format `checkpoint_dir/checkpoint_{step}`.\n    :type checkpoint_interval: int\n\n    :param eval_interval: Evaluate model every eval_interval steps\n    :type eval_interval: int\n\n    :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline\n    :type pipeline: str\n\n    :param trainer: Trainer to use for training. 
One of the registered trainers present in trlx.trainer\n    :type trainer: str\n\n    :param trainer_kwargs: Extra keyword arguments for the trainer\n    :type trainer_kwargs: Dict[str, Any]\n\n    :param project_name: Project name for wandb\n    :type project_name: str\n\n    :param entity_name: Entity name for wandb\n    :type entity_name: str\n\n    :param group_name: Group name for wandb (used for grouping runs)\n    :type group_name: str\n\n    :param checkpoint_dir: Directory to save checkpoints\n    :type checkpoint_dir: str\n\n    :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.\n                                Only used by AcceleratePPOTrainer.\n    :type rollout_logging_dir: Optional[str]\n\n    :param save_best: Save best model based on mean reward\n    :type save_best: bool\n\n    :param seed: Random seed\n    :type seed: int\n    \"\"\"\n\n    total_steps: int\n    seq_length: int\n    epochs: int\n    batch_size: int\n\n    checkpoint_interval: int\n    eval_interval: int\n\n    pipeline: str  # One of the pipelines in framework.pipeline\n    trainer: str  # One of the trainers\n    trainer_kwargs: Dict[str, Any] = field(default_factory=dict)  # Extra keyword arguments for the trainer\n\n    project_name: str = \"trlx\"\n    entity_name: Optional[str] = None\n    group_name: Optional[str] = None\n\n    checkpoint_dir: str = \"ckpts\"\n    rollout_logging_dir: Optional[str] = None\n    save_best: bool = True\n\n    tracker: Optional[str] = \"wandb\"\n    logging_dir: Optional[str] = None\n\n    lora_rank: Optional[int] = 0\n    lora_alpha: Optional[int] = 1\n    lora_train_bias: Optional[str] = \"none\"\n\n    seed: int = 1000\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\n@dataclass\nclass TRLConfig:\n    \"\"\"\n    Top level config for trlX. Loads configs and can be converted to dictionary.\n    \"\"\"\n\n    method: MethodConfig\n    model: ModelConfig\n    optimizer: OptimizerConfig\n    scheduler: SchedulerConfig\n    tokenizer: TokenizerConfig\n    train: TrainConfig\n\n    @classmethod\n    def load_yaml(cls, yml_fp: str):\n        \"\"\"\n        Load yaml file as TRLConfig.\n\n        :param yml_fp: Path to yaml file\n        :type yml_fp: str\n        \"\"\"\n        with open(yml_fp, mode=\"r\") as file:\n            config = yaml.safe_load(file)\n        return cls.from_dict(config)\n\n    def to_dict(self):\n        \"\"\"\n        Convert TRLConfig to dictionary.\n        \"\"\"\n        data = {\n            \"method\": self.method.__dict__,\n            \"model\": self.model.__dict__,\n            \"optimizer\": self.optimizer.__dict__,\n            \"scheduler\": self.scheduler.__dict__,\n            \"tokenizer\": self.tokenizer.__dict__,\n            \"train\": self.train.__dict__,\n        }\n\n        return data\n\n    def evolve(self, **kwargs) -> \"TRLConfig\":\n        \"\"\"\n        Evolve TRLConfig with new parameters. 
Can update nested parameters.\n        >>> config = default_ilql_config()\n        >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))\n        >>> config.method.gamma\n        0.99\n        \"\"\"\n        return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs))\n\n    @classmethod\n    def from_dict(cls, config: Dict):\n        \"\"\"\n        Convert dictionary to TRLConfig.\n        \"\"\"\n        return cls(\n            method=get_method(config[\"method\"][\"name\"]).from_dict(config[\"method\"]),\n            model=ModelConfig.from_dict(config[\"model\"]),\n            tokenizer=TokenizerConfig.from_dict(config[\"tokenizer\"]),\n            optimizer=OptimizerConfig.from_dict(config[\"optimizer\"]),\n            scheduler=SchedulerConfig.from_dict(config[\"scheduler\"]),\n            train=TrainConfig.from_dict(config[\"train\"]),\n        )\n\n    @classmethod\n    def update(cls, baseconfig: Dict, config: Dict):\n        if not isinstance(baseconfig, Dict):\n            baseconfig = baseconfig.to_dict()\n\n        updates = set()\n        merged = merge(baseconfig, config, updates)\n\n        for param in config:\n            if param not in updates:\n                raise ValueError(f\"parameter {param} is not present in the config (typo or a wrong config)\")\n\n        return cls.from_dict(merged)\n\n    def __str__(self):\n        \"\"\"Returns a human-readable string representation of the config.\"\"\"\n        import json\n\n        return json.dumps(self.to_dict(), indent=4)\n\n\ndef default_ppo_config():\n    return TRLConfig(\n        train=TrainConfig(\n            seq_length=1024,\n            epochs=100,\n            total_steps=10000,\n            batch_size=32,\n            checkpoint_interval=10000,\n            eval_interval=100,\n            pipeline=\"PromptPipeline\",\n            trainer=\"AcceleratePPOTrainer\",\n        ),\n        model=ModelConfig(model_path=\"lvwerra/gpt2-imdb\", num_layers_unfrozen=2),\n        tokenizer=TokenizerConfig(tokenizer_path=\"gpt2\", truncation_side=\"right\"),\n        optimizer=OptimizerConfig(\n            name=\"adamw\", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)\n        ),\n        scheduler=SchedulerConfig(name=\"cosine_annealing\", kwargs=dict(T_max=10000, eta_min=1.0e-4)),\n        method=PPOConfig(\n            name=\"PPOConfig\",\n            num_rollouts=128,\n            chunk_size=128,\n            ppo_epochs=4,\n            init_kl_coef=0.05,\n            target=6,\n            horizon=10000,\n            gamma=1,\n            lam=0.95,\n            cliprange=0.2,\n            cliprange_value=0.2,\n            vf_coef=1,\n            scale_reward=\"ignored\",\n            ref_mean=None,\n            ref_std=None,\n            cliprange_reward=10,\n            gen_kwargs=dict(\n                max_new_tokens=40,\n                top_k=0,\n                top_p=1.0,\n                do_sample=True,\n            ),\n        ),\n    )\n\n\ndef default_ilql_config():\n    return TRLConfig(\n        train=TrainConfig(\n            seq_length=64,\n            batch_size=32,\n            epochs=100,\n            total_steps=1000,\n            checkpoint_interval=1000,\n            eval_interval=100,\n            pipeline=\"PromptPipeline\",\n            trainer=\"AccelerateILQLTrainer\",\n        ),\n        model=ModelConfig(model_path=\"gpt2\", num_layers_unfrozen=-1),\n        
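# note: num_layers_unfrozen=-1 above keeps every transformer layer trainable\n        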
tokenizer=TokenizerConfig(tokenizer_path=\"gpt2\", truncation_side=\"right\"),\n        optimizer=OptimizerConfig(\n            name=\"adamw\", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)\n        ),\n        scheduler=SchedulerConfig(\n            name=\"cosine_annealing\", kwargs=dict(T_max=1000, eta_min=5.0e-5)  # train.total_steps\n        ),\n        method=ILQLConfig(\n            name=\"ilqlconfig\",\n            tau=0.7,\n            gamma=0.99,\n            cql_scale=0.1,\n            awac_scale=1,\n            alpha=0.001,\n            beta=0,\n            steps_for_target_q_sync=5,\n            two_qs=True,\n            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0),\n        ),\n    )\n\n\ndef default_sft_config():\n    return TRLConfig(\n        train=TrainConfig(\n            seq_length=1024,\n            epochs=100,\n            total_steps=1000,\n            batch_size=8,\n            checkpoint_interval=10000,\n            eval_interval=100,\n            pipeline=\"PromptPipeline\",\n            trainer=\"AccelerateSFTTrainer\",\n        ),\n        model=ModelConfig(model_path=\"gpt2\", num_layers_unfrozen=-1),\n        tokenizer=TokenizerConfig(tokenizer_path=\"gpt2\", truncation_side=\"right\"),\n        optimizer=OptimizerConfig(\n            name=\"adamw\", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)\n        ),\n        scheduler=SchedulerConfig(\n            name=\"cosine_annealing\", kwargs=dict(T_max=10000, eta_min=1.0e-4)  # train.total_steps\n        ),\n        method=SFTConfig(\n            name=\"sftconfig\",\n            gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True),\n        ),\n    )\n\n\ndef get_train_ds_config(global_batch_size=32,\n                        micro_batch_size=4,\n                        gradient_accumulation_steps=8,\n                        offload=False,\n                        stage=2,\n                        enable_hybrid_engine=False,\n                        inference_tp_size=1,\n                        release_inference_cache=False,\n                        pin_parameters=True,\n                        tp_gather_partition_size=8,\n                        max_out_tokens=512):\n\n    device = \"cpu\" if offload else \"none\"\n    zero_opt_dict = {\n        \"stage\": stage,\n        \"offload_param\": {\n            \"device\": device\n        },\n        \"offload_optimizer\": {\n            \"device\": device\n        },\n        \"stage3_param_persistence_threshold\": 1e4,\n        \"stage3_max_live_parameters\": 3e7,\n        \"stage3_prefetch_bucket_size\": 3e7,\n        \"memory_efficient_linear\": False\n    }\n    return {\n        \"train_batch_size\": global_batch_size,\n        \"train_micro_batch_size_per_gpu\": micro_batch_size,\n        \"gradient_accumulation_steps\": gradient_accumulation_steps,\n        \"steps_per_print\": 10,\n        \"zero_optimization\": zero_opt_dict,\n        \"fp16\": {\n            \"enabled\": True,\n            \"loss_scale_window\": 100\n        },\n        \"gradient_clipping\": 1.0,\n        \"prescale_gradients\": False,\n        \"wall_clock_breakdown\": False,\n        \"hybrid_engine\": {\n            \"enabled\": enable_hybrid_engine,\n            \"max_out_tokens\": max_out_tokens,\n            \"inference_tp_size\": inference_tp_size,\n            \"release_inference_cache\": release_inference_cache,\n            \"pin_parameters\": pin_parameters,\n            
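# forwarded to DeepSpeed as-is; only takes effect when enable_hybrid_engine is True\n            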
\"tp_gather_partition_size\": tp_gather_partition_size,\n        }\n    }\n\n\ndef get_eval_ds_config(global_batch_size=32, micro_batch_size=4, gradient_accumulation_steps=8, offload=False, stage=0):\n    device = \"cpu\" if offload else \"none\"\n    zero_opt_dict = {\n        \"stage\": stage,\n        \"stage3_param_persistence_threshold\": 1e4,\n        \"offload_param\": {\n            \"device\": device\n        },\n        \"memory_efficient_linear\": False\n    }\n    return {\n        \"train_batch_size\": global_batch_size,\n        \"train_micro_batch_size_per_gpu\": micro_batch_size,\n        \"gradient_accumulation_steps\": gradient_accumulation_steps,\n        \"steps_per_print\": 10,\n        \"zero_optimization\": zero_opt_dict,\n        \"fp16\": {\n            \"enabled\": True\n        },\n        \"gradient_clipping\": 1.0,\n        \"prescale_gradients\": False,\n        \"wall_clock_breakdown\": False\n    }\n"
  },
  {
    "path": "src/utils/file_utils.py",
    "content": "\nimport math\nimport os\nimport random\nimport numpy as np\nimport torch\n\nfrom numbers import Number\nfrom pynvml import *\n\nfrom src.utils.logger import logger\n\n\ndef set_seed(seed_val=42):\n    random.seed(seed_val)\n    np.random.seed(seed_val)\n    torch.manual_seed(seed_val)\n    torch.cuda.manual_seed_all(seed_val)\n\n\ndef print_rank_0(*message):\n    \"\"\"\n    Print only once from the main rank\n    \"\"\"\n    if os.environ.get(\"RANK\", \"0\") == \"0\":\n        logger.info(*message)\n\n\ndef significant(x: Number, ndigits=2) -> Number:\n    \"\"\"\n    Cut the number up to its `ndigits` after the most significant\n    \"\"\"\n    if isinstance(x, torch.Tensor):\n        x = x.item()\n\n    if not isinstance(x, Number) or math.isnan(x) or x == 0:\n        return x\n\n    return round(x, ndigits - int(math.floor(math.log10(abs(x)))))\n\n#\n# def set_seed(seed: int):\n#     \"\"\"\n#     Sets seeds across package dependencies for reproducibility.\n#     \"\"\"\n#     seed += int(os.environ.get(\"RANK\", 0))\n#     random.seed(seed)\n#     np.random.seed(seed)\n#     torch.manual_seed(seed)\n#     torch.cuda.manual_seed(seed)\n\n\ndef print_gpu_utilization(prefix: str = \"\", index: int = 0, only_rank_0: bool = True):\n    nvmlInit()\n    handle = nvmlDeviceGetHandleByIndex(index)\n    info = nvmlDeviceGetMemoryInfo(handle)\n    memory_used = info.used / 1024**3\n    if only_rank_0:\n        if index == 0:\n            logger.info(f\"[{prefix}] GPU-{index} memory occupied: {memory_used:.2f} GB\")\n    else:\n        logger.info(f\"[{prefix}] GPU-{index} memory occupied: {memory_used:.2f} GB\")\n\n\ndef print_gpu_utilization_torch(prefix: str = \"\", index: int = 0, only_rank_0: bool = True):\n    memory_allocated = torch.cuda.memory_allocated() / 1024 ** 3\n    max_memory_allocated = torch.cuda.max_memory_allocated() / 1024 ** 3\n    memory_reserved = torch.cuda.memory_reserved() / 1024 ** 3\n    # max_memory_reserved = torch.cuda.max_memory_reserved() / 1024 ** 3\n    if only_rank_0:\n        if index == 0:\n            logger.info(f\"[{prefix}] GPU-{index}: memory allocated: {memory_allocated:.2f} GB, \"\n                        f\"max memory allocated: {max_memory_allocated:.2f} GB, \"\n                        f\"memory reserved: {memory_reserved:.2f} GB, \"\n                        # f\"max memory reserved: {max_memory_allocated:.2f} GB\"\n                        )\n    else:\n        logger.info(f\"[{prefix}] GPU-{index}: memory allocated: {memory_allocated:.2f} GB, \"\n                    f\"max memory allocated: {max_memory_allocated:.2f} GB, \"\n                    f\"memory reserved: {memory_reserved:.2f} GB, \"\n                    # f\"max memory reserved: {max_memory_reserved:.2f} GB\"\n                    )\n\n\ndef print_trainable_parameters(model):\n    \"\"\"\n    Prints the number of trainable parameters in the model.\n    \"\"\"\n    trainable_params = 0\n    all_param = 0\n    for _, param in model.named_parameters():\n        all_param += param.numel()\n        if param.requires_grad:\n            trainable_params += param.numel()\n    logger.info(\n        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n    )"
  },
  {
    "path": "src/utils/loading.py",
    "content": "import types\nimport torch\nimport glob\nfrom typing import Callable, Dict, Tuple\n\n# Register load pipelines via module import\nfrom transformers import (\n    AutoConfig,\n    AutoTokenizer,\n    AutoModelForCausalLM,\n    AutoModelForSeq2SeqLM,\n    BitsAndBytesConfig,\n    PreTrainedModel,\n    PreTrainedTokenizer,\n)\nfrom transformers.models.llama import LlamaModel, LlamaForCausalLM\nfrom accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch\nfrom accelerate.utils import get_balanced_memory\nfrom peft import (\n    prepare_model_for_kbit_training,\n    LoraConfig,\n    get_peft_model\n)\n\nfrom src.utils import print_trainable_parameters\nfrom src.utils.modeling_utils import _prepare_decoder_attention_mask\nfrom src.data.pipeline import _DATAPIPELINE\nfrom src.models.trainer import _TRAINERS, register_trainer\n# from trlx.pipeline.offline_pipeline import PromptPipeline\n# from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer\n# from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer\n# from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer\n\ntry:\n    from src.models.trainer import NeMoILQLTrainer\nexcept ImportError:\n    # NeMo is not installed\n    def _trainer_unavailble(name):\n        def log_error(*args, **kwargs):\n            raise ImportError(f\"Unable to import NeMo so {name} is unavailable\")\n\n        return register_trainer(name)(log_error)\n\n    _trainer_unavailble(\"NeMoILQLTrainer\")\n\n\ndef prepare_decoder_attention_mask(self, *args, **kwargs):\n    return _prepare_decoder_attention_mask(*args, **kwargs)\n\n\ndef chatglm_auto_configure_device_map(num_gpus: int, model_name: str, local_rank: int = 0) -> Dict[str, int]:\n    # transformer.word_embeddings 占用1层\n    # transformer.final_layernorm 和 lm_head 占用1层\n    # transformer.layers 占用 28 层\n    # 总共30层分配到num_gpus张卡上\n    num_hidden_layers = 28\n    layers_per_gpu = (num_hidden_layers+2) // num_gpus\n    layer_prefix = 'transformer'\n\n    # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError\n    # windows下 model.device 会被设置成 transformer.word_embeddings.device\n    # linux下 model.device 会被设置成 lm_head.device\n    # 在调用chat或者stream_chat时,input_ids会被放到model.device上\n    # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError\n    # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上\n\n    encode = \"\"\n    if 'chatglm2' in model_name:\n        device_map = {\n            f\"{layer_prefix}.embedding.word_embeddings\": local_rank,\n            f\"{layer_prefix}.rotary_pos_emb\": local_rank,\n            f\"{layer_prefix}.output_layer\": local_rank,\n            f\"{layer_prefix}.encoder.final_layernorm\": local_rank,\n            f\"base_model.model.output_layer\": local_rank,\n        }\n        encode = \".encoder\"\n    else:\n        device_map = {\n            f'{layer_prefix}.word_embeddings': local_rank,\n            f'{layer_prefix}.final_layernorm': local_rank,\n            'lm_head': local_rank,\n            f'base_model.model.lm_head': local_rank,\n        }\n    used = 2\n    gpu_target = 0\n    # TODO: Assuming CUDA device index is consecutive, e.g. 
cuda:0, cuda:1, cuda:2\n    for i in range(num_hidden_layers):\n        if used >= layers_per_gpu + (gpu_target % 2):\n            gpu_target += 1\n            gpu_target %= num_gpus\n            used = 0\n        device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target + local_rank\n        used += 1\n\n    return device_map\n\n\ndef llama_and_baichuan_auto_configure_device_map(num_gpus: int, model_name: str, local_rank: int = 0) -> Dict[str, int]:\n    layer_prefix = 'model'\n    # model.embed_tokens takes 1 layer\n    # model.norm and lm_head take 1 layer\n    # model.layers takes num_hidden_layers layers\n    # num_hidden_layers+2 layers in total, distributed across num_gpus cards\n    if \"7b\" in model_name.lower():\n        num_hidden_layers = 32\n    elif \"13b\" in model_name.lower():\n        num_hidden_layers = 40\n    else:\n        raise ValueError(f\"Only supports baichuan-7B, baichuan-13B, llama-7B and llama-13B, but {model_name} is provided\")\n\n    layers_per_gpu = (num_hidden_layers+2) // num_gpus\n    device_map = {\n        f'{layer_prefix}.embed_tokens': local_rank,\n        f'{layer_prefix}.norm': local_rank,\n        'lm_head': local_rank,\n        f'base_model.model.lm_head': local_rank,\n    }\n    used = 2\n    gpu_target = 0\n    # TODO: Assuming CUDA device index is consecutive, e.g. cuda:0, cuda:1, cuda:2\n    for i in range(num_hidden_layers):\n        if used >= layers_per_gpu + (gpu_target % 2):\n            gpu_target += 1\n            gpu_target %= num_gpus\n            used = 0\n        device_map[f'{layer_prefix}.layers.{i}'] = gpu_target + local_rank\n        used += 1\n\n    return device_map\n\n\ndef load_params_8bit_or_4bit(args, model: PreTrainedModel) -> Dict:\n    # init bnb config for quantization\n    bf16 = torch.cuda.get_device_capability()[0] >= 8\n    if bf16:\n        bnb_4bit_compute_dtype = torch.bfloat16\n    else:\n        bnb_4bit_compute_dtype = torch.float16\n    bnb_config = BitsAndBytesConfig(\n        load_in_8bit=args.bits == 8,\n        load_in_4bit=args.bits == 4,\n        bnb_4bit_use_double_quant=True,\n        bnb_4bit_quant_type=\"nf4\",\n        bnb_4bit_compute_dtype=bnb_4bit_compute_dtype\n    )\n    params = {\n        \"low_cpu_mem_usage\": True,\n        'quantization_config': bnb_config\n    }\n    # infer device map\n    if args.multi_card:\n        max_memory = get_balanced_memory(model, dtype=torch.int8, low_zero=False,\n                                         no_split_module_classes=model._no_split_modules)\n        params['device_map'] = infer_auto_device_map(\n            model,\n            dtype=torch.int8,\n            no_split_module_classes=model._no_split_modules,\n            max_memory=max_memory\n        )\n    else:\n        params['device_map'] = {\"\": args.local_rank}\n\n    return params\n\n\ndef load_tokenizer_and_model(args, with_trainer: bool = True) -> Tuple[PreTrainedTokenizer, PreTrainedModel, int]:\n    # load tokenizer\n    tokenizer_path = args.tokenizer_path if hasattr(args, \"tokenizer_path\") else args.model_name_or_path\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)\n\n    # set eos token\n    if \"chatglm3\" in args.model_name_or_path.lower():\n        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command(\"<|user|>\"), tokenizer.get_command(\"<|observation|>\")]\n    elif \"chatglm2\" in args.model_name_or_path.lower():\n        eos_token_id = tokenizer.get_command(\"eop\") if args.checkpoint is not None else tokenizer.get_command(\"<eos>\")\n    elif \"chatglm1_1\" in 
args.model_name_or_path.lower():\n        eos_token_id = tokenizer.eos_token_id\n    elif \"chatglm\" in args.model_name_or_path.lower():\n        eos_token_id = tokenizer.eop_token_id\n    elif \"baichuan\" in args.model_name_or_path.lower():\n        eos_token_id = tokenizer.bos_token_id if args.checkpoint is not None else tokenizer.eos_token_id\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n    else:\n        eos_token_id = tokenizer.eos_token_id\n        if tokenizer.pad_token_id is None:\n            tokenizer.pad_token_id = tokenizer.eos_token_id\n\n    # load model\n    if \"chatglm\" in args.model_name_or_path.lower():\n        model_class = AutoModelForSeq2SeqLM\n    else:\n        model_class = AutoModelForCausalLM\n\n    if torch.cuda.is_available():\n        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16\n    else:\n        dtype = torch.float32\n    params = {\n        \"trust_remote_code\": True,\n        \"torch_dtype\": dtype,\n        \"load_in_8bit\": hasattr(args, \"bits\") and args.bits == 8,\n        \"load_in_4bit\": hasattr(args, \"bits\") and args.bits == 4,\n        # \"quantization_config\": bnb_config,\n    }\n    if with_trainer:\n        params[\"device_map\"] = args.device_map\n        params[\"low_cpu_mem_usage\"] = args.low_cpu_mem_usage\n    model = model_class.from_pretrained(args.model_name_or_path,\n                                        **params)\n    # # cpu\n    # if not torch.cuda.is_available():\n    #     model = model_class.from_pretrained(args.model_name_or_path,\n    #                                         trust_remote_code=True)\n    # # 8bit or 4bit\n    # elif hasattr(args, \"bits\") and args.bits in [4, 8]:\n    #     config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)\n    #     model = model_class.from_config(config, trust_remote_code=True)\n    #     params = load_params_8bit_or_4bit(args, model)\n    #     model = model_class.from_pretrained(args.model_name_or_path,\n    #                                         trust_remote_code=True,\n    #                                         **params)\n    #     if args.do_train:\n    #         if args.gradient_checkpointing:\n    #             model.gradient_checkpointing_enable()\n    #         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)\n    # # multi gpu card\n    # elif hasattr(args, \"multi_card\") and args.multi_card:\n    #     with init_empty_weights():\n    #         config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)\n    #         model = model_class.from_config(config, trust_remote_code=True).half()\n    #     model.tie_weights()\n    #     if \"llama\" in args.model_name_or_path.lower() or \\\n    #         \"baichuan\" in args.model_name_or_path.lower() or \\\n    #         \"vicuna\" in args.model_name_or_path.lower():\n    #         device_map = llama_and_baichuan_auto_configure_device_map(\n    #             torch.cuda.device_count(),\n    #             args.model_name_or_path.lower(),\n    #             args.local_rank\n    #         )\n    #     elif \"chatglm\" in args.model_name_or_path.lower():\n    #         device_map = chatglm_auto_configure_device_map(\n    #             torch.cuda.device_count(),\n    #             args.model_name_or_path.lower(),\n    #             args.local_rank\n    #         )\n    #     else:\n    #         #     max_memory = get_balanced_memory(model, 
dtype=torch.float16, low_zero=False,\n    #         #                                      no_split_module_classes=model._no_split_modules)\n    #         #     device_map = infer_auto_device_map(model, dtype=torch.float16, max_memory=max_memory,\n    #         #                                        no_split_module_classes=model._no_split_modules)\n    #         device_map = \"auto\"\n    #\n    #     model = load_checkpoint_and_dispatch(model,\n    #                                          checkpoint=args.model_name_or_path,\n    #                                          device_map=device_map,\n    #                                          no_split_module_classes=model._no_split_modules,\n    #                                          dtype=torch.float16)\n    # # single gpu card\n    # else:\n    #     model = model_class.from_pretrained(args.model_name_or_path,\n    #                                         trust_remote_code=True,\n    #                                         torch_dtype=torch.float16,\n    #                                         device_map={\"\": args.local_rank})\n\n    # post-loading operations\n    if hasattr(args, \"concat_samples\") and args.concat_samples and isinstance(model, LlamaForCausalLM):\n        prepare_decoder_attention_mask_method = types.MethodType(prepare_decoder_attention_mask, model.model)\n        setattr(model.model, \"_prepare_decoder_attention_mask\", prepare_decoder_attention_mask_method)\n        # funcType = type(LlamaModel._prepare_decoder_attention_mask)\n        # model.model._prepare_decoder_attention_mask = funcType(prepare_decoder_attention_mask, model.model, LlamaModel)\n    if \"pangu\" in args.model_name_or_path.lower():\n        model.resize_token_embeddings(tokenizer.vocab_size)\n    if hasattr(args, \"bits\") and args.bits in [4, 8] and args.do_train:\n        if args.gradient_checkpointing:\n            model.gradient_checkpointing_enable()\n        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)\n\n    # init peft model (if necessary)\n    if hasattr(args, \"lora_rank\") and args.lora_rank > 0:\n        model = to_peft_model(args, model)\n\n    return tokenizer, model, eos_token_id\n\n\ndef to_peft_model(args, model: PreTrainedModel) -> PreTrainedModel:\n    if \"llama\" in args.model_name_or_path.lower() or \\\n        \"vicuna\" in args.model_name_or_path.lower() or \\\n        \"billa\" in args.model_name_or_path.lower() or \\\n        \"atomgpt\" in args.model_name_or_path.lower() or \\\n        \"pangu\" in args.model_name_or_path.lower():\n        target_modules = [\"q_proj\", \"k_proj\", \"v_proj\"]\n        task_type = \"CAUSAL_LM\"\n    elif \"baichuan\" in args.model_name_or_path.lower():\n        target_modules = [\"W_pack\"]\n        task_type = \"CAUSAL_LM\"\n    elif \"bloom\" in args.model_name_or_path.lower() or \"tigerbot\" in args.model_name_or_path.lower():\n        target_modules = [\"query_key_value\"]\n        task_type = \"CAUSAL_LM\"\n    elif \"glm\" in args.model_name_or_path.lower():\n        target_modules = [\"query_key_value\"]\n        task_type = \"SEQ_2_SEQ_LM\"\n    else:\n        raise ValueError(f\"Unsupported model name: {args.model_name_or_path}\")\n\n    config = LoraConfig(\n        r=args.lora_rank,\n        lora_alpha=args.lora_alpha,\n        target_modules=target_modules,\n        lora_dropout=0.05,\n        bias=args.lora_train_bias,\n        task_type=task_type\n    )\n    model.enable_input_require_grads()\n    model = 
get_peft_model(model, config)\n    print_trainable_parameters(model)\n\n    return model\n\n\ndef load_checkpoint(args, model: PreTrainedModel, strict: bool = True) -> None:\n    checkpoints = glob.glob(args.checkpoint.replace(\"star\", \"*\"))\n    st = dict()\n    for checkpoint in checkpoints:\n        st.update(torch.load(checkpoint, map_location=\"cpu\"))\n    model.load_state_dict(st, strict=strict)\n    del st\n\n\ndef get_trainer(name: str) -> Callable:\n    \"\"\"\n    Return constructor for specified RL model trainer\n    \"\"\"\n    name = name.lower()\n    if name in _TRAINERS:\n        return _TRAINERS[name]\n    else:\n        raise Exception(\"Error: Trying to access a trainer that has not been registered\")\n\n\ndef get_pipeline(name: str) -> Callable:\n    \"\"\"\n    Return constructor for specified pipeline\n    \"\"\"\n    name = name.lower()\n    if name in _DATAPIPELINE:\n        return _DATAPIPELINE[name]\n    else:\n        raise Exception(\"Error: Trying to access a pipeline that has not been registered\")\n"
  },
  {
    "path": "src/utils/logger.py",
    "content": "\nimport os\nimport logging.config\n\nfrom src.utils.config import RESOURCE_PATH\n\nlogging.config.fileConfig(os.path.join(RESOURCE_PATH, \"config\", \"log.conf\"), disable_existing_loggers=False)\nlogger = logging.getLogger(__name__)\n"
  },
  {
    "path": "src/utils/method_configs.py",
    "content": "import sys\nimport torch\nimport torch.nn.functional as F\n\nfrom functools import reduce\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple\nfrom torchtyping import TensorType\n\nfrom src.utils.modeling_utils import (\n    flatten_dict,\n    get_tensor_stats,\n    whiten,\n)\n\n# specifies a dictionary of method configs\n_METHODS: Dict[str, Any] = {}  # registry\n\n\ndef register_method(name):\n    \"\"\"Decorator used register a method config\n    Args:\n        name: Name of the method\n    \"\"\"\n\n    def register_class(cls, name):\n        _METHODS[name] = cls\n        setattr(sys.modules[__name__], name, cls)\n        return cls\n\n    if isinstance(name, str):\n        name = name.lower()\n        return lambda c: register_class(c, name)\n\n    cls = name\n    name = cls.__name__\n    register_class(cls, name.lower())\n\n    return cls\n\n\n@dataclass\n@register_method\nclass MethodConfig:\n    \"\"\"\n    Config for a certain RL method.\n\n    :param name: Name of the method\n    :type name: str\n    \"\"\"\n\n    name: str\n\n    @classmethod\n    def from_dict(cls, config: Dict[str, Any]):\n        return cls(**config)\n\n\ndef get_method(name: str) -> MethodConfig:\n    \"\"\"\n    Return constructor for specified method config\n    \"\"\"\n    name = name.lower()\n    if name in _METHODS:\n        return _METHODS[name]\n    else:\n        raise Exception(\"Error: Trying to access a method that has not been registered\")\n\n\n# PPO Configs\n@dataclass\n@register_method\nclass PPOConfig(MethodConfig):\n    \"\"\"\n    Config for PPO method\n\n    :param ppo_epochs: Number of updates per batch\n    :type ppo_epochs: int\n\n    :param num_rollouts: Number  of experiences to observe before learning\n    :type num_rollouts: int\n\n    :param init_kl_coef: Initial value for KL coefficient\n    :type init_kl_coef: float\n\n    :param target: Target value for KL coefficient\n    :type target: float\n\n    :param horizon: Number of steps for KL coefficient to reach target\n    :type horizon: int\n\n    :param gamma: Discount factor\n    :type gamma: float\n\n    :param lam: GAE lambda\n    :type lam: float\n\n    :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)\n    :type cliprange: float\n\n    :param cliprange_value: Clipping range for predicted values\n                            (observed values - cliprange_value, observed values + cliprange_value)\n    :type cliprange_value: float\n\n    :param vf_coef: Value loss scale w.r.t policy loss\n    :type vf_coef: float\n\n    :param gen_kwargs: Additioanl kwargs for the generation\n    :type gen_kwargs: Dict[str, Any]\n\n    :param gen_experience_kwargs: if this is not None, then the experience is generated using this\n    :type gen_experience_kwargs: Dict[str, Any]\n    \"\"\"\n\n    ppo_epochs: int\n    num_rollouts: int\n    chunk_size: int\n    init_kl_coef: float\n    target: float\n    horizon: int\n    gamma: float\n    lam: float\n    cliprange: float\n    cliprange_value: float\n    vf_coef: float\n    scale_reward: Optional[str]\n    ref_mean: Optional[float]\n    ref_std: Optional[float]\n    cliprange_reward: float\n    gen_kwargs: dict\n    gen_experience_kwargs: Optional[dict] = None\n\n    def get_advantages_and_returns(\n            self,\n            values: TensorType[\"batch_size\", \"response_size\"],\n            rewards: TensorType[\"batch_size\", \"response_size\"],\n            response_length: int,\n            use_whitening: 
Optional[bool] = True,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Function that computes advantages and returns from rewards and values.\n        Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347\n        Note that rewards may include a KL divergence loss term.\n\n        Advantages looks like this:\n        Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...\n              - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...\n\n        Returns looks like this:\n        Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...\n                   + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...\n\n        Args:\n            values: Tensor of shape (batch_size, response_size)\n            rewards: Tensor of shape (batch_size, response_size)\n            response_length: Length of the response sequence\n            use_whitening: Whether to use whitening (ie. normalize advantages) or not\n        \"\"\"\n        lastgaelam = 0\n        advantages_reversed = []\n        for t in reversed(range(response_length)):\n            nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0\n            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]\n            lastgaelam = delta + self.gamma * self.lam * lastgaelam\n            advantages_reversed.append(lastgaelam)\n        advantages = torch.stack(advantages_reversed[::-1], dim=1)\n        returns = advantages + values\n        if use_whitening:\n            advantages = whiten(advantages)\n        return advantages.detach(), returns\n\n    def loss(\n            self,\n            logprobs: TensorType[\"batch_size\", \"response_size\"],\n            values: TensorType[\"batch_size\", \"response_size\"],\n            old_logprobs: TensorType[\"batch_size\", \"response_size\"],\n            old_values: TensorType[\"batch_size\", \"response_size\"],\n            advantages: TensorType[\"batch_size\", \"response_size\"],\n            returns: TensorType[\"batch_size\", \"response_size\"],\n            mask: TensorType[\"batch_size\", \"response_size\"],\n    ):\n        \"\"\"PPO objective function.\n        References:\n        - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html\n        \"\"\"\n        print(f\"[ppo loss] values shape: {values.shape}, old_values shape: {old_values.shape}\")\n        values_clipped = torch.clamp(\n            values,\n            old_values - self.cliprange_value,\n            old_values + self.cliprange_value,\n            )\n        n = mask.sum()\n\n        vf_loss1 = (values - returns) ** 2\n        vf_loss2 = (values_clipped - returns) ** 2\n        vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n\n        vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n\n\n        log_ratio = (logprobs - old_logprobs) * mask\n        ratio = torch.exp(log_ratio)\n        # Unbiased KL-div estimates (`k3`). 
Ref: http://joschu.net/blog/kl-approx.html\n        with torch.no_grad():\n            approx_kl = torch.mean((ratio - 1) - log_ratio)\n\n        pg_loss1 = -advantages * ratio\n        pg_loss2 = -advantages * torch.clamp(\n            ratio,\n            1.0 - self.cliprange,\n            1.0 + self.cliprange,\n            )\n        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n\n        pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n\n\n        loss = pg_loss + self.vf_coef * vf_loss\n\n        stats = dict(\n            losses=dict(\n                total_loss=loss.item(),\n                policy_loss=pg_loss.item(),\n                value_loss=vf_loss.item(),\n            ),\n            values=dict(\n                get_tensor_stats(values, mask, n),\n                values_error=torch.sum(((values - returns) * mask) ** 2) / n,\n                clipfrac=vf_clipfrac,\n            ),\n            old_values=get_tensor_stats(old_values, mask, n),\n            returns=get_tensor_stats(returns, mask, n),\n            policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()),\n            ratio=(ratio * mask).sum() / n,\n            padding_percentage=n / mask.numel(),\n        )\n\n        return loss, flatten_dict(stats)\n\n\n@dataclass\n@register_method\nclass SFTConfig(MethodConfig):\n    \"\"\"\n    Config for SFT training\n\n    :param gen_kwargs: kwargs for generation\n    :type gen_kwargs: Dict[str, Any]\n    \"\"\"\n\n    gen_kwargs: dict\n\n\n@dataclass\n@register_method\nclass ILQLConfig(MethodConfig):\n    tau: float\n    gamma: float\n    cql_scale: float\n    awac_scale: float\n    alpha: float\n    beta: float\n    steps_for_target_q_sync: float\n    two_qs: bool\n    gen_kwargs: dict\n\n    def loss(self, outputs, labels):\n        logits, (qs, target_qs, vs) = outputs\n        terminal_mask = labels.dones[:, :-1]\n        n_nonterminal = max(1, terminal_mask.sum())\n        # check type of labels\n        if isinstance(labels, ILQLBatch):\n            actions = labels.input_ids[:, 1:].gather(dim=1, index=labels.actions_ixs).unsqueeze(-1)\n        else:\n            actions = labels.decoder_input_ids[:, 1:].unsqueeze(-1)\n        nactions = actions.shape[1]\n        bsize, _, dsize = logits.shape\n\n        Q = [q.gather(-1, actions).squeeze(-1) for q in qs]\n        targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs]\n        targetQ = reduce(torch.minimum, targetQs)\n\n        # values of current states\n        V = vs[:, :-1].squeeze()\n        # values of next states\n        Vnext = vs[:, 1:].squeeze() * labels.dones[:, 1:]\n        # target to fit Q\n        Q_ = labels.rewards + self.gamma * Vnext.detach()\n\n        loss_qs = [((Qi - Q_) * terminal_mask).pow(2).sum() / n_nonterminal for Qi in Q]\n        loss_q = sum(loss_qs)\n\n        targetQ = targetQ.detach()\n\n        loss_v = (\n                         (\n                                 (targetQ >= V).int() * self.tau * (targetQ - V).pow(2)\n                                 + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2)\n                         )\n                         * terminal_mask\n                 ).sum() / n_nonterminal\n\n        def cql_loss(q):\n            loss = F.cross_entropy(q.reshape(-1, dsize), actions.reshape(-1), reduction=\"none\")\n            loss = loss.reshape(bsize, nactions) * terminal_mask\n            loss = loss.sum() / n_nonterminal\n            return loss\n\n        loss_cql = 
sum(cql_loss(q) for q in qs)\n\n        # select logits from continuations\n        action_logits = batched_index_select(logits, labels.actions_ixs, dim=1)\n        cross_entropy = F.cross_entropy(\n            action_logits.reshape(-1, dsize),\n            actions.reshape(-1),\n            reduction=\"none\",\n        ).reshape(bsize, nactions)\n\n        with torch.no_grad():\n            awac_weight = torch.exp(self.beta * (targetQ - V))\n\n        loss_awac = torch.sum(cross_entropy * awac_weight * terminal_mask) / n_nonterminal\n        loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac\n\n        stats = dict(\n            losses=dict(\n                loss=loss.item(),\n                loss_q=loss_q.item(),\n                loss_v=loss_v.item(),\n                loss_cql=loss_cql.item(),\n                loss_awac=loss_awac.item(),\n            ),\n            values=get_tensor_stats(V, terminal_mask, n_nonterminal),\n            qvalues={str(ix): get_tensor_stats(Q[ix], terminal_mask, n_nonterminal) for ix in range(len(Q))},\n            awac_weight=get_tensor_stats(awac_weight, terminal_mask, n_nonterminal),\n        )\n\n        return loss, flatten_dict(stats)\n"
  },
  {
    "path": "src/utils/modeling_utils.py",
    "content": "\nimport functools\nfrom typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional\n\nimport os\nimport subprocess\nimport time\nimport numpy as np\nimport re\nimport shutil\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport transformers\nimport deepspeed\n\nfrom pathlib import Path\nfrom deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\nfrom dataclasses import is_dataclass\nfrom enum import Enum\nfrom accelerate import Accelerator\nfrom torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR\nfrom transformers import PreTrainedTokenizer\n\ntry:\n    from opendelta import (\n        AdapterModel,\n        BitFitModel,\n        LoraModel,\n        PrefixModel,\n        SoftPromptModel,\n    )\n\n    HAS_OPENDELTA = True\nexcept ModuleNotFoundError:\n    HAS_OPENDELTA = False\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\ndef _prepare_decoder_attention_mask(attention_mask, input_shape, input_embeds, past_key_values_length):\n    # create causal mask\n    if attention_mask is not None:\n        # attention mask is already given in 3d shape\n        if len(attention_mask.shape) == 3:\n            return attention_mask.unsqueeze(1).to(input_embeds.device)\n\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask(\n            input_shape,\n            input_embeds.dtype,\n            device=input_embeds.device,\n            past_key_values_length=past_key_values_length,\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(attention_mask, input_embeds.dtype, tgt_len=input_shape[-1]).to(\n            input_embeds.device\n        )\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\ndef get_distributed_config(accelerator: 
Accelerator):\n    \"\"\"\n    Return accelerator distributed config\n    \"\"\"\n\n    dist_config = {\n        \"mixed_precision\": accelerator.mixed_precision,\n        \"num_gpus\": accelerator.num_processes,\n    }\n\n    if accelerator.state.deepspeed_plugin is not None:\n        ds_plugin = accelerator.state.deepspeed_plugin\n        dist_config.update(\n            {\n                \"gradient_accumulation_steps\": ds_plugin.gradient_accumulation_steps,\n                \"gradient_clipping\": ds_plugin.gradient_clipping,\n                \"zero_stage\": ds_plugin.zero_stage,\n                \"offload_optimizer_device\": ds_plugin.offload_optimizer_device,\n                \"offload_param_device\": ds_plugin.offload_param_device,\n            }\n        )\n\n    return dist_config\n\n\nclass OptimizerName(str, Enum):\n    \"\"\"Supported optimizer names\"\"\"\n\n    ADAM: str = \"adam\"\n    ADAMW: str = \"adamw\"\n    ADAM_8BIT_BNB: str = \"adam_8bit_bnb\"\n    ADAMW_8BIT_BNB: str = \"adamw_8bit_bnb\"\n    SGD: str = \"sgd\"\n\n\ndef get_optimizer_class(name: OptimizerName):\n    \"\"\"\n    Returns the optimizer class with the given name\n\n    Args:\n        name (str): Name of the optimizer as found in `OptimizerNames`\n    \"\"\"\n    if name == OptimizerName.ADAM:\n        return torch.optim.Adam\n    if name == OptimizerName.ADAMW:\n        return torch.optim.AdamW\n    if name == OptimizerName.ADAM_8BIT_BNB.value:\n        try:\n            from bitsandbytes.optim import Adam8bit\n\n            return Adam8bit\n        except ImportError:\n            raise ImportError(\n                \"You must install the `bitsandbytes` package to use the 8-bit Adam. \"\n                \"Install with: `pip install bitsandbytes`\"\n            )\n    if name == OptimizerName.ADAMW_8BIT_BNB.value:\n        try:\n            from bitsandbytes.optim import AdamW8bit\n\n            return AdamW8bit\n        except ImportError:\n            raise ImportError(\n                \"You must install the `bitsandbytes` package to use 8-bit AdamW. \"\n                \"Install with: `pip install bitsandbytes`\"\n            )\n    if name == OptimizerName.SGD.value:\n        return torch.optim.SGD\n    supported_optimizers = [o.value for o in OptimizerName]\n    raise ValueError(f\"`{name}` is not a supported optimizer. \" f\"Supported optimizers are: {supported_optimizers}\")\n\n\nclass SchedulerName(str, Enum):\n    \"\"\"Supported scheduler names\"\"\"\n\n    COSINE_ANNEALING = \"cosine_annealing\"\n    LINEAR = \"linear\"\n\n\ndef get_scheduler_class(name: SchedulerName):\n    \"\"\"\n    Returns the scheduler class with the given name\n    \"\"\"\n    if name == SchedulerName.COSINE_ANNEALING:\n        return CosineAnnealingLR\n    if name == SchedulerName.LINEAR:\n        return LinearLR\n    supported_schedulers = [s.value for s in SchedulerName]\n    raise ValueError(f\"`{name}` is not a supported scheduler. \" f\"Supported schedulers are: {supported_schedulers}\")\n\n\nclass Clock:\n    \"\"\"\n    Helper object for keeping track of time for computations.\n    \"\"\"\n\n    def __init__(self):\n        self.start = time.time()\n        self.total_time = 0\n        self.total_samples = 0\n\n    def tick(self, samples: int = 0) -> float:\n        \"\"\"\n        Returns time (s) since last call to tick(). 
Also records samples processed since last call.\n\n        :param samples: number of samples that have been processed since last call\n        \"\"\"\n        end = time.time()\n        delta = end - self.start\n        self.start = end\n\n        if samples != 0:\n            self.total_time += delta\n            self.total_samples += samples\n\n        return delta\n\n    def get_stat(self, n_samp: int = 1000, reset: bool = False):\n        \"\"\"\n        Returns average time (s) per n_samp samples processed\n\n        :param n_samp: number of samples to compute the average time over\n        :param reset: Reset counts?\n        \"\"\"\n        sec_per_samp = self.total_time / self.total_samples\n\n        if reset:\n            self.total_samples = 0\n            self.total_time = 0\n\n        return sec_per_samp * n_samp\n\n\ndef tree_map(f, tree: Any) -> Any:\n    \"\"\"\n    Apply function f to all leaves in tree\n    \"\"\"\n    if is_dataclass(tree):\n        return tree.__class__(**{k: tree_map(f, v) for k, v in tree.__dict__.items()})\n    elif isinstance(tree, dict):\n        return {k: tree_map(f, v) for k, v in tree.items()}\n    elif isinstance(tree, (list, tuple)):\n        return tree.__class__(tree_map(f, v) for v in tree)\n    else:\n        return f(tree)\n\n\ndef to_device(tree, device, non_blocking=False):\n    \"\"\"\n    Move all tensors in tree to device\n    \"\"\"\n    return tree_map(lambda x: x.to(device, non_blocking=non_blocking), tree)\n\n\ndef filter_non_scalars(xs: Dict) -> Dict:\n    \"\"\"\n    Trims everything that can't be cast to float\n    \"\"\"\n    ys = {}\n    for k, v in xs.items():\n        try:\n            ys[k] = float(v)\n        except (TypeError, ValueError):\n            continue\n\n    return ys\n\n\ndef get_git_tag() -> Tuple[str, str]:\n    \"\"\"\n    Returns the current git branch, plus the commit's short hash and date\n    \"\"\"\n    try:\n        output = subprocess.check_output(\"git log --format='%h/%as' -n1\".split())\n        branch = subprocess.check_output(\"git rev-parse --abbrev-ref HEAD\".split())\n        return branch.decode()[:-1], output.decode()[1:-2]\n    except subprocess.CalledProcessError:\n        return \"unknown\", \"unknown\"\n\n\ndef make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:\n    \"\"\"Returns a generic sequential MLP head.\"\"\"\n    return nn.Sequential(\n        nn.Linear(n_embd, n_embd * 2, dtype=dtype),\n        nn.ReLU(),\n        nn.Linear(n_embd * 2, out, dtype=dtype),\n    )\n\n\ndef freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):\n    \"\"\"Freezes the bottom transformer block layers of the specified model.\"\"\"\n    hidden_layers = hf_get_decoder_blocks(model)\n    if num_layers_unfrozen == 0:\n        hidden_layers_to_freeze = list(hidden_layers)\n    elif num_layers_unfrozen > 0:\n        hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]\n    else:\n        hidden_layers_to_freeze = []\n    for layer in hidden_layers_to_freeze:\n        layer.requires_grad_(False)\n\n\ndef freeze_bottom_seq2seq_layers(model: nn.Module, num_layers_unfrozen: int = 0):\n    \"\"\"Freezes the bottom transformer block layers of the specified model.\"\"\"\n    if num_layers_unfrozen == -1:\n        return\n    shared_embed = model.shared\n    decoder_embed = model.decoder.embed_tokens\n    encoder_blocks = model.encoder.block\n    encoder_norm_layer = model.encoder.final_layer_norm\n    decoder_norm_layer = model.decoder.final_layer_norm\n    # a slice of [:-0] would be empty, so handle num_layers_unfrozen == 0 explicitly\n    decoder_blocks = model.decoder.block[:-num_layers_unfrozen] if num_layers_unfrozen > 0 else model.decoder.block\n    blocks_to_freeze = (\n        
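# freeze the shared/decoder embeddings, every encoder block, both final\n        # layer norms, and all but the last num_layers_unfrozen decoder blocks\n        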
list(encoder_blocks)\n        + list(decoder_blocks)\n        + [shared_embed]\n        + [encoder_norm_layer]\n        + [decoder_norm_layer]\n        + [decoder_embed]\n    )\n    for block in blocks_to_freeze:\n        block.requires_grad_(False)\n\n\ndef rhasattr(obj, attr):\n    \"\"\"A chain-able attribute version of hasattr. For example, to check if\n    `obj` has the attribute `foo.bar.baz`, you can use:\n        `rhasattr(obj, \"foo.bar.baz\")`\n    Reference: https://stackoverflow.com/a/67303315\n    \"\"\"\n    _nested_attrs = attr.split(\".\")\n    _curr_obj = obj\n    for _a in _nested_attrs[:-1]:\n        if hasattr(_curr_obj, _a):\n            _curr_obj = getattr(_curr_obj, _a)\n        else:\n            return False\n    return hasattr(_curr_obj, _nested_attrs[-1])\n\n\ndef rgetattr(obj, attr: str, *args) -> object:\n    \"\"\"A chain-able attribute version of getattr. For example, to get the\n    attribute `foo.bar.baz` from `obj`, you can use:\n        `rgetattr(obj, \"foo.bar.baz\")`\n    Reference: https://stackoverflow.com/a/31174427\n    \"\"\"\n\n    def _getattr(obj, attr):\n        return getattr(obj, attr, *args)\n\n    return functools.reduce(_getattr, [obj] + attr.split(\".\"))\n\n\ndef findattr(obj, attrs: Tuple[str, ...]) -> Union[object, None]:\n    for attr in attrs:\n        if rhasattr(obj, attr):\n            return rgetattr(obj, attr)\n    raise ValueError(f\"Could not find an attribute from `{attrs}` in `{obj}`\")\n\n\ndef hf_get_decoder(model: nn.Module) -> nn.Module:\n    \"\"\"Returns the causal decoder backbone of the specified HuggingFace transformers\n    model.\n    NOTE: Different model architectures have different causal decoder attribute\n    names.\n        - transformer: (GPT2LMHeadModel, GPTJForCausalLM, BloomForCausalLM)\n        - model.decoder: (OPTForCausalLM)\n        - gpt_neox: (GPTNeoXForCausalLM)\n    \"\"\"\n    decoder_attrs = (\"transformer\", \"model.decoder\", \"gpt_neox\", \"decoder\")\n    return findattr(model, decoder_attrs)\n\n\ndef hf_get_decoder_final_norm(model: nn.Module) -> nn.Module:\n    \"\"\"Returns the final (layer) norm of the specified decoder.\n    NOTE: Different model configurations have different final norm attribute names.\n        - transformer.ln_f: (GPT2LMHeadModel, GPTJForCausalLM)\n        - model.decoder.final_layer_norm: (OPTForCausalLM)\n        - gpt_neox.final_layer_norm: (GPTNeoXForCausalLM)\n    \"\"\"\n    norm_attrs = (\n        \"transformer.ln_f\",\n        \"model.decoder.final_layer_norm\",\n        \"decoder.final_layer_norm\",\n        \"gpt_neox.final_layer_norm\",\n    )\n    return findattr(model, norm_attrs)\n\n\ndef hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]:\n    \"\"\"Returns the decoder hidden layers of the specified model.\n    NOTE: Different model configurations have different hidden layer attribute names.\n        - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)\n        - model.decoder.layers: (OPTForCausalLM)\n        - gpt_neox.layers: (GPTNeoXForCausalLM)\n        - decoder.block: (T5ForConditionalGeneration)\n    \"\"\"\n    hidden_layers_attrs = (\n        \"h\",\n        \"layers\",\n        \"decoder.layers\",\n        \"transformer.h\",\n        \"model.decoder.layers\",\n        \"gpt_neox.layers\",\n        \"decoder.block\",\n        \"glm.transformer.layers\",\n    )\n    return findattr(model, hidden_layers_attrs)\n\n\ndef hf_get_lm_head(model: nn.Module) -> nn.Module:\n    \"\"\"Returns the language modeling (lm) head of the specified 
HuggingFace\n    transformers model.\n    NOTE: Different model configurations have different `lm_head` attribute names.\n        - lm_head: (GPT2LMHeadModel, BloomForCausalLM)\n        - embed_out: (GPTNeoXForCausalLM)\n    \"\"\"\n    return model.get_output_embeddings()\n\n\ndef hf_get_hidden_size(config: transformers.PretrainedConfig) -> int:\n    \"\"\"Returns the hidden layer dimensionality of the model architecture specified\n    by the HuggingFace transformers config.\n    NOTE: Different model configurations have different hidden size attribute names.\n        - hidden_size: (OPTConfig, BloomConfig)\n        - n_embd: (GPT2Config, GPTJConfig)\n        - d_model: (PegasusConfig, XLNetConfig)\n    \"\"\"\n    hidden_size_attrs = (\"hidden_size\", \"n_embd\", \"d_model\")\n    return findattr(config, hidden_size_attrs)\n\n\ndef hf_get_num_hidden_layers(config: transformers.PretrainedConfig) -> int:\n    \"\"\"Returns the number of hidden layers in the model architecture specified\n    by the HuggingFace transformers config.\n    NOTE: Different model configurations have different number-of-layers attribute\n    names.\n        - num_hidden_layers: (GPTNeoXConfig, OPTConfig)\n        - n_layer: (GPT2Config, GPTJConfig, BloomConfig)\n    \"\"\"\n    num_hidden_layers_attrs = (\"num_hidden_layers\", \"n_layer\")\n    return findattr(config, num_hidden_layers_attrs)\n\n\ndef get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:\n    \"\"\"\n    Computes the global mean and variance of the tensor across all processes\n    \"\"\"\n    sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)\n    dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)\n    global_sum, count = sum_and_count\n    global_mean = global_sum / count\n\n    sum_var = torch.sum((xs - global_mean) ** 2)\n    dist.all_reduce(sum_var, dist.ReduceOp.SUM)\n    global_var = sum_var / count\n    return global_mean, global_var, count\n\n\ndef whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:\n    \"\"\"Whitens values to zero mean and unit variance (optionally restoring the original mean)\"\"\"\n    if distributed and dist.is_initialized():\n        mean, var, _ = get_global_statistics(xs)\n    else:\n        var, mean = torch.var_mean(xs)\n\n    whitened = (xs - mean) * torch.rsqrt(var + 1e-8)\n    if not shift_mean:\n        whitened += mean\n    return whitened\n\n\ndef logprobs_of_labels(logits, labels):\n    \"\"\"Log probabilities of the labels\n\n    These are calculated from the logits.\"\"\"\n    logprobs = F.log_softmax(logits, dim=-1)\n    logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))\n    return logprobs_labels.squeeze(-1)\n\n\ndef flatten_dict(\n    d: Union[dict, MutableMapping],\n    parent_key: str = \"\",\n    sep: str = \"/\",\n) -> dict:\n    # From: https://stackoverflow.com/a/6027615\n    items = []\n    for k, v in d.items():\n        new_key = parent_key + sep + k if parent_key else k\n        if isinstance(v, MutableMapping):\n            items.extend(flatten_dict(v, new_key, sep=sep).items())\n        else:\n            items.append((new_key, v))\n    return dict(items)\n\n\ndef get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int):\n    \"\"\"Returns the masked mean, min, max and std of xs, where n is the number of unmasked elements.\"\"\"\n    mean = (xs * mask).sum() / n\n    return dict(\n        mean=mean,\n        min=torch.where(mask.bool(), xs, np.inf).min(),\n        max=torch.where(mask.bool(), xs, -np.inf).max(),\n        std=torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),\n    )\n\n\nclass RunningMoments:\n    def __init__(self):\n        \"\"\"\n        Calculates the running 
mean and standard deviation of a data stream. Modified version of\n        https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py\n        \"\"\"\n        self.mean = 0\n        self.std = 1\n        self.var = 1\n        self.count = 1e-24\n\n    def update(self, xs: torch.Tensor) -> Tuple[float, float]:\n        \"\"\"Updates running moments from batch's moments computed across ranks\"\"\"\n        if dist.is_initialized():\n            xs_mean, xs_var, xs_count = get_global_statistics(xs)\n        else:\n            xs_count = xs.numel()\n            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)\n\n        delta = xs_mean - self.mean\n        tot_count = self.count + xs_count\n\n        new_sum = xs_var * xs_count\n        # correct old_sum deviation accounting for the new mean\n        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count\n        tot_sum = old_sum + new_sum\n\n        self.mean += delta * xs_count / tot_count\n        self.var = tot_sum / tot_count\n        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()\n        self.count = tot_count\n\n        return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()\n\n\n# OpenDelta utilities\n\n\nMODIFIED_MODULES_DICT = {\n    \"gptj\": {\n        \"attention\": [\"attn.q_proj\", \"attn.k_proj\", \"attn.v_proj\"],\n        \"mlp\": [\"mlp.fc_in\", \"mlp.fc_out\"],\n        \"all\": [\n            \"attn.q_proj\",\n            \"attn.k_proj\",\n            \"attn.v_proj\",\n            \"attn.out_proj\",\n            \"mlp.fc_in\",\n            \"mlp.fc_out\",\n        ],\n    },\n    \"gpt_neox\": {\n        \"attention\": [\"attention.query_key_value\"],\n        \"mlp\": [\"mlp.dense_h_to_4h\", \"mlp.dense_4h_to_h\"],\n        \"all\": [\n            \"attention.query_key_value\",\n            \"attention.dense\",\n            \"mlp.dense_h_to_4h\",\n            \"mlp.dense_4h_to_h\",\n        ],\n    },\n    \"opt\": {\n        \"attention\": [\n            \"self_attn.k_proj\",\n            \"self_attn.v_proj\",\n            \"self_attn.q_proj\",\n            \"self_attn.out_proj\",\n        ],\n        \"mlp\": [\"fc1\", \"fc2\"],\n        \"all\": [\n            \"self_attn.k_proj\",\n            \"self_attn.v_proj\",\n            \"self_attn.q_proj\",\n            \"self_attn.out_proj\",\n            \"fc1\",\n            \"fc2\",\n        ],\n    },\n    \"bloom\": {\n        \"attention\": [\"self_attention.query_key_value\", \"self_attention.dense\"],\n        \"mlp\": [\"mlp.dense_h_to_4h\", \"mlp.dense_4h_to_h\"],\n        \"all\": [\n            \"self_attention.query_key_value\",\n            \"self_attention.dense\",\n            \"mlp.dense_h_to_4h\",\n            \"mlp.dense_4h_to_h\",\n        ],\n    },\n    \"t5\": {\n        \"attention\": [\n            \"layer.0.SelfAttention.q\",\n            \"layer.0.SelfAttention.k\",\n            \"layer.0.SelfAttention.v\",\n            \"layer.0.SelfAttention.o\",\n            \"layer.1.EncDecAttention.q\",\n            \"layer.1.EncDecAttention.k\",\n            \"layer.1.EncDecAttention.v\",\n            \"layer.1.EncDecAttention.o\",\n        ],\n        \"mlp\": [\n            \"layer.2.DenseReluDense.wo\",\n            \"layer.2.DenseReluDense.wi_0\",\n            \"layer.2.DenseReluDense.wi_1\",\n        ],\n        \"all\": [\n            \"layer.0.SelfAttention.q\",\n            \"layer.0.SelfAttention.k\",\n            
\"layer.0.SelfAttention.v\",\n            \"layer.0.SelfAttention.o\",\n            \"layer.1.EncDecAttention.q\",\n            \"layer.1.EncDecAttention.k\",\n            \"layer.1.EncDecAttention.v\",\n            \"layer.1.EncDecAttention.o\",\n            \"layer.2.DenseReluDense.wo\",\n            \"layer.2.DenseReluDense.wi_0\",\n            \"layer.2.DenseReluDense.wi_1\",\n        ],\n    },\n}\n\n\ndef generate_layer_regex(config: transformers.PretrainedConfig, num_layers_unfrozen: int = -1) -> str:\n    \"\"\"Generates a regex range for the specified number of learnable layers.\"\"\"\n    if num_layers_unfrozen == -1:\n        return \"(\\d)+.\"\n    num_hidden_layers = hf_get_num_hidden_layers(config)\n    start_layer = num_hidden_layers - num_layers_unfrozen\n    if start_layer < 0:\n        raise Exception(\"Number of layers unfrozen cannot be greater than number of layers in the model\")\n    pattern = f\"(?:{regex_for_range(start_layer, num_hidden_layers - 1)}).\"\n    return f\"{pattern}\"\n\n\ndef get_delta_modified_modules(\n    config: transformers.PretrainedConfig,\n    modified_modules: List[str],\n    num_layers_unfrozen: int = -1,\n) -> List[str]:\n    \"\"\"Returns a list of module names to be modified for a given delta method with\n    the specified number of learnable layers.\"\"\"\n    unfrozen_layers_pattern = generate_layer_regex(config, num_layers_unfrozen)\n\n    # [r] for regex as per https://github.com/thunlp/OpenDelta/blob/main/opendelta/utils/name_based_addressing.py#L20\n    regex_prefix = \"[r]\"\n    # TODO (jon-tow): `decoder.block.` is hardcoded to support T5 layer naming.\n    decoder_prefix = \"decoder.block.\" if config.is_encoder_decoder else \"\"\n    module_list = [regex_prefix + decoder_prefix + unfrozen_layers_pattern + module for module in modified_modules]\n    return module_list\n\n\ndef get_delta_model_class(model_type: str):\n    if not HAS_OPENDELTA:\n        raise ValueError(\"OpenDelta package required to train with delta models. 
https://github.com/thunlp/OpenDelta.\")\n    delta_models = {\n        \"bitfit\": BitFitModel,\n        \"adapter\": AdapterModel,\n        \"prefix\": PrefixModel,\n        \"lora\": LoraModel,\n        \"softprompt\": SoftPromptModel,\n    }\n    return delta_models[model_type]\n\n\ndef parse_delta_kwargs(\n    config: transformers.PretrainedConfig,\n    delta_kwargs: Dict[str, Any],\n    num_layers_unfrozen: int = -1,\n) -> Tuple[str, Dict[str, Any]]:\n    \"\"\"Parses through delta kwargs to get delta type and proper modified modules.\"\"\"\n    # This function is needed to parse through the `delta_kwargs` in order to:\n    # 1) Get the `delta_type` method name to access the correct `delta_model_class`\n    # 2a) Accept user specified `modified_modules` and if not provided use the `trlx` default mapping\n    # 2b) Convert the list of `modified_modules` to a range of layers that fit within the range\n    #    of learnable layers as specified by `num_layers_unfrozen`\n\n    # Pop `delta_type` to allow passing the kwargs to the model constructor since\n    # `delta_type` is not a valid argument of the constructor\n    delta_type = delta_kwargs.pop(\"delta_type\")\n    assert delta_type in [\"lora\"], \"Only `LoRA` based delta models are supported\"\n\n    # Use `trlx` default modified modules if none are specified\n    modified_modules = delta_kwargs.get(\"modified_modules\", \"all\")\n    if modified_modules in [\"all\", \"attention\", \"mlp\"]:\n        if config.model_type not in MODIFIED_MODULES_DICT:\n            raise ValueError(\n                f\"Model type `{config.model_type}` is not currently supported for \"\n                \"delta training with default modified modules.\"\n            )\n        modified_modules = MODIFIED_MODULES_DICT[config.model_type][modified_modules]\n    # Update the `modified_modules` with the correct layer ranges\n    delta_kwargs[\"modified_modules\"] = get_delta_modified_modules(\n        config, modified_modules, num_layers_unfrozen=num_layers_unfrozen\n    )\n\n    return delta_type, delta_kwargs\n\n\ndef regex_for_range(min_: int, max_: int) -> str:  # noqa\n    \"\"\"Returns a regex that matches all numbers in the given range.\n\n    Example: regex_for_range(12, 34) -> \"1[2-9]|2\\d|3[0-4]\"\n\n    Copyright (c) 2013, Dmitry Voronin. 
All rights reserved.\n    Reference: https://github.com/voronind/range-regex\n    \"\"\"\n\n    def split_to_patterns(min_, max_):\n        subpatterns = []\n        start = min_\n        for stop in split_to_ranges(min_, max_):\n            subpatterns.append(range_to_pattern(start, stop))\n            start = stop + 1\n        return subpatterns\n\n    def split_to_ranges(min_, max_):\n        stops = {max_}\n        nines_count = 1\n        stop = fill_by_nines(min_, nines_count)\n        while min_ <= stop < max_:\n            stops.add(stop)\n            nines_count += 1\n            stop = fill_by_nines(min_, nines_count)\n        zeros_count = 1\n        stop = fill_by_zeros(max_ + 1, zeros_count) - 1\n        while min_ < stop <= max_:\n            stops.add(stop)\n            zeros_count += 1\n            stop = fill_by_zeros(max_ + 1, zeros_count) - 1\n        stops = list(stops)\n        stops.sort()\n        return stops\n\n    def fill_by_nines(integer, nines_count):\n        return int(str(integer)[:-nines_count] + \"9\" * nines_count)\n\n    def fill_by_zeros(integer, zeros_count):\n        return integer - integer % 10**zeros_count\n\n    def range_to_pattern(start, stop):\n        pattern = \"\"\n        any_digit_count = 0\n        for start_digit, stop_digit in zip(str(start), str(stop)):\n            if start_digit == stop_digit:\n                pattern += start_digit\n            elif start_digit != \"0\" or stop_digit != \"9\":\n                pattern += \"[{}-{}]\".format(start_digit, stop_digit)\n            else:\n                any_digit_count += 1\n        if any_digit_count:\n            pattern += r\"\\d\"\n        if any_digit_count > 1:\n            pattern += \"{{{}}}\".format(any_digit_count)\n        return pattern\n\n    positive_subpatterns = []\n    negative_subpatterns = []\n\n    if min_ < 0:\n        min__ = 1\n        if max_ < 0:\n            min__ = abs(max_)\n        max__ = abs(min_)\n        negative_subpatterns = split_to_patterns(min__, max__)\n        min_ = 0\n    if max_ >= 0:\n        positive_subpatterns = split_to_patterns(min_, max_)\n\n    negative_only_subpatterns = [\"-\" + val for val in negative_subpatterns if val not in positive_subpatterns]\n    positive_only_subpatterns = [val for val in positive_subpatterns if val not in negative_subpatterns]\n    intersected_subpatterns = [\"-?\" + val for val in negative_subpatterns if val in positive_subpatterns]\n    subpatterns = negative_only_subpatterns + intersected_subpatterns + positive_only_subpatterns\n    return \"|\".join(subpatterns)\n\n\ndef get_optimizer_grouped_parameters(model,\n                                     weight_decay,\n                                     no_decay_name_list=[\"bias\", \"LayerNorm.weight\"]):\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [\n                p for n, p in model.named_parameters()\n                if (not any(nd in n\n                            for nd in no_decay_name_list) and p.requires_grad)\n            ],\n            \"weight_decay\":\n                weight_decay,\n        },\n        {\n            \"params\": [\n                p for n, p in model.named_parameters()\n                if (any(nd in n\n                        for nd in no_decay_name_list) and p.requires_grad)\n            ],\n            \"weight_decay\":\n                0.0,\n        },\n    ]\n    return optimizer_grouped_parameters\n\n\ndef get_all_reduce_mean(tensor):\n    torch.distributed.all_reduce(tensor, 
op=torch.distributed.ReduceOp.SUM)\n    tensor = tensor / torch.distributed.get_world_size()\n    return tensor\n\n\ndef moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):\n    zero_stage_3 = (zero_stage == 3)\n    with torch.no_grad():\n        for param, param_ema in zip(model.parameters(),\n                                    model_ema.parameters()):\n            # TODO: use prefiltering for efficiency\n            params_to_fetch = _z3_params_to_fetch([param, param_ema\n                                                   ]) if zero_stage_3 else []\n            should_gather_param = len(params_to_fetch) > 0\n            with deepspeed.zero.GatheredParameters(\n                    params_to_fetch, enabled=should_gather_param):\n                data = param.data\n                if device is not None:\n                    data = data.to(device)\n                # ema = beta * ema + (1 - beta) * param, via torch.lerp(start, end, weight)\n                param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))\n\n\ndef save_hf_format(model, tokenizer, args, sub_folder=\"\"):\n    # used to save huggingface format, so we can use it for hf.from_pretrained\n    model_to_save = model.module if hasattr(model, 'module') else model\n    CONFIG_NAME = \"config.json\"\n    WEIGHTS_NAME = \"pytorch_model.bin\"\n    output_dir = os.path.join(args.output_dir, sub_folder)\n    os.makedirs(output_dir, exist_ok=True)\n    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)\n    output_config_file = os.path.join(output_dir, CONFIG_NAME)\n    save_dict = model_to_save.state_dict()\n    for key in list(save_dict.keys()):\n        if \"lora\" in key:\n            del save_dict[key]\n    torch.save(save_dict, output_model_file)\n    model_to_save.config.to_json_file(output_config_file)\n    tokenizer.save_vocabulary(output_dir)\n\n\ndef _z3_params_to_fetch(param_list):\n    return [\n        p for p in param_list\n        if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE\n    ]\n\n\ndef save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):\n    zero_stage_3 = (zero_stage == 3)\n    os.makedirs(save_dir, exist_ok=True)\n    WEIGHTS_NAME = \"pytorch_model.bin\"\n    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)\n\n    model_to_save = model_ema.module if hasattr(model_ema, 'module') else model_ema\n    if not zero_stage_3:\n        if global_rank == 0:\n            torch.save(model_to_save.state_dict(), output_model_file)\n    else:\n        output_state_dict = {}\n        for k, v in model_to_save.named_parameters():\n            if hasattr(v, 'ds_id'):\n                with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v]),\n                                                       enabled=zero_stage_3):\n                    # copy to CPU while the full parameter is gathered; it is\n                    # repartitioned again when the context exits\n                    v_p = v.data.cpu()\n            else:\n                v_p = v.cpu()\n            if global_rank == 0 and \"lora\" not in k:\n                output_state_dict[k] = v_p\n        if global_rank == 0:\n            torch.save(output_state_dict, output_model_file)\n        del output_state_dict\n\n\ndef sorted_checkpoints(output_dir=None, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n    ordering_and_checkpoint_path = []\n\n    glob_checkpoints = [str(x) for x in Path(output_dir).glob(f\"{checkpoint_prefix}-*\") if os.path.isdir(x)]\n\n    for path in glob_checkpoints:\n        if use_mtime:\n            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n        else:\n            regex_match = 
re.match(f\".*{checkpoint_prefix}-([0-9]+)\", path)\n            if regex_match is not None and regex_match.groups() is not None:\n                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n\n    checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n\n    return checkpoints_sorted\n\n\ndef rotate_checkpoints(save_total_limit, use_mtime=False, output_dir=None, best_model_checkpoint=None) -> None:\n    if save_total_limit is None or save_total_limit <= 0:\n        return\n\n    # Check if we should delete older checkpoint(s)\n    checkpoints_sorted = sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)\n    if len(checkpoints_sorted) <= save_total_limit:\n        return\n\n    # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which\n    # we don't do to allow resuming.\n    save_total_limit_tmp = save_total_limit\n    if (\n            best_model_checkpoint is not None\n            and save_total_limit == 1\n            and checkpoints_sorted[-1] != best_model_checkpoint\n    ):\n        save_total_limit_tmp = 2\n\n    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit_tmp)\n    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n    for checkpoint in checkpoints_to_be_deleted:\n        print(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n        shutil.rmtree(checkpoint, ignore_errors=True)\n\n\ndef qwen_make_context(\n        tokenizer: PreTrainedTokenizer,\n        query: str,\n        history: List[Tuple[str, str]] = None,\n        system: str = \"\",\n        max_window_size: int = 6144,\n        chat_format: str = \"chatml\",\n):\n    if history is None:\n        history = []\n\n    if chat_format == \"chatml\":\n        im_start, im_end = \"<|im_start|>\", \"<|im_end|>\"\n\n        def _tokenize_str(content_1, content_2=None,\n                          add_special_tokens=True,\n                          truncation=False,\n                          max_length=None):\n            # if content_2 is not None:\n            #     content = f\"{im_start}{content_1}{im_end}\\n{im_start}{content_2}{im_end}\\n\"\n            # else:\n            #     content = f\"{im_start}{content_1}{im_end}\\n\"\n            token_ids = tokenizer.encode(content_1, content_2,\n                                         add_special_tokens=add_special_tokens,\n                                         max_length=max_length,\n                                         truncation=truncation)\n            content = tokenizer.decode(token_ids)\n            return content, token_ids\n\n        # system message\n        system_text, system_tokens = _tokenize_str(f\"system\\n{system}\")\n        # current-turn user query\n        max_query_length = max_window_size - len(system_tokens) + 2\n        assert max_query_length > 0, f\"System message length ({len(system_tokens)}) has exceeded max window size ({max_window_size})\"\n        query_text, query_tokens = _tokenize_str(f\"user\\n{query}\", f\"assistant\\n\",\n                                                 truncation=True, max_length=max_query_length)\n        # remove additional \"<|im_end|>\" and \"\\n\"\n        query_text = query_text.strip(\"\\n\").strip(im_end)\n        query_tokens = query_tokens[:-2]\n        # history\n        raw_text = \"\"\n        context_tokens = []\n\n        for 
turn_query, turn_response in reversed(history):\n            # walk the history from the most recent turn, prepending turns while they still fit\n            prev_chat, next_context_tokens = _tokenize_str(f\"user\\n{turn_query}\", f\"assistant\\n{turn_response}\")\n            current_context_size = (\n                    len(system_tokens) + len(query_tokens) + len(next_context_tokens) + len(context_tokens)\n            )\n            if current_context_size < max_window_size:\n                context_tokens = next_context_tokens + context_tokens\n                raw_text = prev_chat + raw_text\n            else:\n                break\n\n        raw_text = system_text + raw_text + query_text\n        context_tokens = system_tokens + context_tokens + query_tokens\n\n    elif chat_format == \"raw\":\n        raw_text = query\n        context_tokens = tokenizer.encode(raw_text)\n    else:\n        raise NotImplementedError(f\"Unknown chat format {chat_format!r}\")\n\n    return raw_text, context_tokens\n
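\n\nif __name__ == \"__main__\":\n    # Minimal, self-contained sanity checks for a few of the pure helpers above.\n    # Illustrative only: the inputs are arbitrary and this block is not used by\n    # the training pipeline.\n    pattern = re.compile(f\"^(?:{regex_for_range(12, 34)})$\")\n    assert all(pattern.match(str(i)) for i in range(12, 35))\n    assert not any(pattern.match(str(i)) for i in (11, 35, 120))\n\n    # flatten_dict joins nested keys with the separator\n    assert flatten_dict({\"a\": {\"b\": 1}, \"c\": 2}) == {\"a/b\": 1, \"c\": 2}\n\n    # RunningMoments tracks the mean/std of a stream of batches\n    rm = RunningMoments()\n    for _ in range(10):\n        rm.update(torch.randn(1024))\n    print(f\"running mean={float(rm.mean):.3f}, std={float(rm.std):.3f}\")  # approx. 0 and 1\n"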
  },
  {
    "path": "src/utils/nlp_utils.py",
    "content": "\nimport re\n\nCLEAN_TEXT_PATTERN = re.compile(r\"[\\r\\n]\")\n\n\ndef clean_text(text):\n    if not isinstance(text, str):\n        return \"\"\n    return CLEAN_TEXT_PATTERN.sub(\"\", text)\n"
  }
]