Repository: jongwooko/distillm
Branch: master
Commit: d47e77ff9d27
Files: 130
Total size: 390.6 KB
Directory structure:
gitextract_j6ug111_/
├── .gitignore
├── README.md
├── arguments.py
├── configs/
│ ├── deepspeed/
│ │ ├── ds_config.json
│ │ ├── ds_config_fp32.json
│ │ ├── ds_config_zero2.json
│ │ └── ds_config_zero2_offload.json
│ └── hostfiles/
│ ├── node_0_1
│ ├── node_0_1_2_3
│ ├── node_1_2
│ └── node_2_3
├── data_utils/
│ ├── distributed_indexed.py
│ ├── indexed_dataset.py
│ ├── lm_datasets.py
│ └── prompt_datasets.py
├── distillm/
│ ├── __init__.py
│ ├── buffer.py
│ ├── losses.py
│ └── sampler.py
├── evaluate.py
├── evaluate_main.py
├── finetune.py
├── generate.py
├── install.sh
├── minillm/
│ ├── __init__.py
│ ├── data_types.py
│ ├── losses.py
│ ├── model.py
│ ├── pipelines.py
│ ├── reward.py
│ ├── sampler.py
│ ├── storages.py
│ ├── trainer.py
│ └── utils.py
├── rouge_metric.py
├── scripts/
│ ├── gpt2/
│ │ ├── distillm/
│ │ │ ├── train_0.1B_1.5B.sh
│ │ │ ├── train_0.3B_1.5B.sh
│ │ │ └── train_0.7B_1.5B.sh
│ │ ├── eval/
│ │ │ ├── eval_main_dolly.sh
│ │ │ ├── eval_main_self_inst.sh
│ │ │ ├── eval_main_sinst.sh
│ │ │ ├── eval_main_uinst.sh
│ │ │ ├── eval_main_vicuna.sh
│ │ │ └── run_eval.sh
│ │ ├── gkd/
│ │ │ ├── gkd_base.sh
│ │ │ ├── gkd_large.sh
│ │ │ └── gkd_medium.sh
│ │ ├── imitkd/
│ │ │ ├── imitkd_base.sh
│ │ │ ├── imitkd_large.sh
│ │ │ └── imitkd_medium.sh
│ │ ├── init/
│ │ │ ├── init_base.sh
│ │ │ ├── init_large.sh
│ │ │ └── init_medium.sh
│ │ ├── kd/
│ │ │ ├── kd_base.sh
│ │ │ ├── kd_large.sh
│ │ │ └── kd_medium.sh
│ │ ├── minillm/
│ │ │ ├── train_base_xl.sh
│ │ │ ├── train_large_xl.sh
│ │ │ └── train_medium_xl.sh
│ │ ├── seqkd/
│ │ │ ├── seqkd_base.sh
│ │ │ ├── seqkd_large.sh
│ │ │ └── seqkd_medium.sh
│ │ ├── sft/
│ │ │ ├── sft_base.sh
│ │ │ ├── sft_large.sh
│ │ │ ├── sft_medium.sh
│ │ │ └── sft_xlarge.sh
│ │ └── tools/
│ │ ├── generate_data_seqkd.sh
│ │ ├── process_data_dolly.sh
│ │ ├── process_data_pretrain.sh
│ │ └── process_pseudo_data_seqkd.sh
│ ├── openllama2/
│ │ ├── distillm/
│ │ │ └── train_3B_7B_teacher_lora.sh
│ │ ├── eval/
│ │ │ ├── eval_main_dolly_lora.sh
│ │ │ ├── eval_main_self_inst_lora.sh
│ │ │ ├── eval_main_sinst_lora.sh
│ │ │ ├── eval_main_uinst_lora.sh
│ │ │ ├── eval_main_vicuna_lora.sh
│ │ │ └── run_eval.sh
│ │ ├── gkd/
│ │ │ └── gkd_3B_7B_teacher_lora.sh
│ │ ├── imitkd/
│ │ │ └── imitkd_3B_7B_teacher_lora.sh
│ │ ├── init/
│ │ │ └── sft_3B_lora.sh
│ │ ├── kd/
│ │ │ └── kd_3B_7B_teacher_lora.sh
│ │ ├── minillm/
│ │ │ └── train_3B_7B_lora.sh
│ │ ├── seqkd/
│ │ │ └── seqkd_3B_7B_teacher_lora.sh
│ │ ├── sft/
│ │ │ ├── sft_3B_lora.sh
│ │ │ └── sft_7B_lora.sh
│ │ └── tools/
│ │ ├── generate_data_seqkd.sh
│ │ ├── process_data_dolly.sh
│ │ ├── process_data_pretrain.sh
│ │ └── process_pseudo_data_seqkd.sh
│ └── opt/
│ ├── distillm/
│ │ ├── train_0.1B_2.7B.sh
│ │ ├── train_0.3B_2.7B.sh
│ │ └── train_1.3B_2.7B.sh
│ ├── eval/
│ │ ├── eval_main_dolly.sh
│ │ ├── eval_main_self_inst.sh
│ │ ├── eval_main_sinst.sh
│ │ ├── eval_main_uinst.sh
│ │ ├── eval_main_vicuna.sh
│ │ └── run_eval.sh
│ ├── gkd/
│ │ ├── gkd_0.1B_2.7B.sh
│ │ ├── gkd_0.3B_2.7B.sh
│ │ └── gkd_1.3B_2.7B.sh
│ ├── imitkd/
│ │ ├── imitkd_0.1B_2.7B.sh
│ │ ├── imitkd_0.3B_2.7B.sh
│ │ └── imitkd_1.3B_2.7B.sh
│ ├── init/
│ │ ├── init_0.1B.sh
│ │ ├── init_0.3B.sh
│ │ └── init_1.3B.sh
│ ├── kd/
│ │ ├── kd_0.1B_2.7B.sh
│ │ ├── kd_0.3B_2.7B.sh
│ │ └── kd_1.3B_2.7B.sh
│ ├── minillm/
│ │ ├── train_0.1B_2.7B.sh
│ │ ├── train_0.3B_2.7B.sh
│ │ └── train_1.3B_2.7B.sh
│ ├── seqkd/
│ │ ├── seqkd_0.1B_2.7B.sh
│ │ ├── seqkd_0.3B_2.7B.sh
│ │ └── seqkd_1.3B_2.7B.sh
│ ├── sft/
│ │ ├── sft_0.1B.sh
│ │ ├── sft_0.3B.sh
│ │ ├── sft_1.3B.sh
│ │ └── sft_2.7B.sh
│ └── tools/
│ ├── generate_data_seqkd.sh
│ ├── process_data_dolly.sh
│ ├── process_data_pretrain.sh
│ └── process_pseudo_data_seqkd.sh
├── tools/
│ ├── convert_mp.py
│ ├── get_openwebtext.py
│ ├── process_data_dolly.py
│ └── process_data_pretrain.py
├── train_minillm.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks
### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/
.ipynb_checkpoints
*/.ipynb_checkpoints/*
# IPython
profile_default/
ipython_config.py
# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
# IPython
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
./results/
./data/
./checkpoints/
./processed_data/
*.tar.gz
================================================
FILE: README.md
================================================
# DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)
<a href="https://arxiv.org/abs/2402.03898"><img src="https://img.shields.io/badge/Paper-arXiv:2402.03898-Green"></a>
<a href=#bibtex><img src="https://img.shields.io/badge/Paper-BibTex-yellow"></a>
Official PyTorch implementation of **DistiLLM**, as presented in our paper: \
\
**DistiLLM: Towards Streamlined Distillation for Large Language Models** \
*[Jongwoo Ko](https://sites.google.com/view/jongwooko), [Sungnyun Kim](https://sungnyunkim.notion.site/Sungnyun-Kim-4770a0182c47469ebdcd357cde97bd32), Tianyi Chen, Se-Young Yun* \
KAIST AI and Microsoft
## 🚀 Updates
- [x] (25.03.11) DistiLLM-2 paper is out! The preliminary code will be available in this repo, and the final code will be available [here](https://github.com/jongwooko/distillm-2) soon.
- [x] (24.08.12) Removed the dependency on the outdated local transformers package. You can now work with various types of recent LLMs!
- [x] (24.05.01) Our paper has been accepted to **ICML 2024**. We are open to any discussion and will reflect it in the camera-ready version. Looking forward to seeing you in Vienna!
- [x] (24.03.13) Release [**LoRA checkpoints for OpenLLaMa2-3B**](https://drive.google.com/drive/folders/1Yun1aNpn-mz2h-IVH_VdJ1Jhzm0K55Bo?usp=sharing)
## Environment
```bash
bash install.sh
```
Our code is based on [this commit](https://github.com/huggingface/transformers/commit/85fde09c97213bf7e8625f83096bb2a9e183f987) of HuggingFace Transformers, following MiniLLM.
## Data
### Resources
+ The training/evaluation instruction-response data before processing can be downloaded from this [link](https://conversationhub.blob.core.windows.net/beit-share-public/MiniLLM/data.tar?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D).
+ The plain-text corpus $\mathcal{D}_\text{PT}$ can be downloaded from the HuggingFace datasets [repository](https://huggingface.co/datasets/openwebtext).
### Data Processing
Get plain-text corpus $\mathcal{D}_\text{PT}$:
```bash
python3 tools/get_openwebtext.py
```
This script replaces consecutive `\n` characters in each document with the special token "<@x(x!>" and writes each OpenWebText document on a single line, which is convenient for parallel processing. `data/openwebtext/data.txt` gives an example of the resulting format. You can follow this format to prepare corpora beyond OpenWebText.
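For clarity, here is a minimal sketch of this one-document-per-line convention (the token string is the one quoted above; the exact handling of newline runs in `tools/get_openwebtext.py` may differ slightly):

```python
NEWLINE_TOKEN = "<@x(x!>"  # stands in for "\n" so that each document fits on one line

def doc_to_line(doc: str) -> str:
    return doc.replace("\n", NEWLINE_TOKEN)

def line_to_doc(line: str) -> str:
    return line.rstrip("\n").replace(NEWLINE_TOKEN, "\n")

# data/openwebtext/data.txt then holds one transformed document per line:
# with open("data/openwebtext/data.txt") as f:
#     docs = [line_to_doc(l) for l in f]
```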
Tokenize the data and store them in binary files:
```bash
bash scripts/gpt2/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/gpt2/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data
bash scripts/opt/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/opt/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data
bash scripts/openllama2/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/openllama2/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data
```
## Base Pre-trained Models
To run fine-tuning or standard KD baselines, you need to download the model checkpoints from the [HuggingFace Model Hub](https://huggingface.co/models) and put them in `checkpoints/`. For example, for gpt2-large, you can download the model from this [link](https://huggingface.co/gpt2-large/tree/main) and put it in `checkpoints/gpt2-large`.
Alternatively, you can change the `CKPT` variable in each script to the corresponding model name so that Transformers downloads the base model automatically. For example, setting `CKPT="gpt2-large"` in `scripts/gpt2/sft/sft_large.sh` causes the gpt2-large base model to be downloaded from the HuggingFace model hub.
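If you prefer to pre-download explicitly, a short sketch using the standard `transformers` API (model name and target directory mirror the example above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2-large"
AutoModelForCausalLM.from_pretrained(name).save_pretrained(f"checkpoints/{name}")
AutoTokenizer.from_pretrained(name).save_pretrained(f"checkpoints/{name}")
```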
## Train
We provide example commands for GPT-2 models. Similar scripts for other model families can be found in `scripts/opt` and `scripts/openllama2`. All our experiments are conducted on 4 × A100 (40GB) GPUs; smaller models can be trained with fewer GPUs.
### Baselines
The final checkpoints are selected by the **ROUGE-L** scores.
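For reference, ROUGE-L is an LCS-based F-measure between a generated response and the reference answer; selection keeps the checkpoint with the highest validation score. A minimal sketch of the metric (the repo's `rouge_metric.py` is the scorer actually used; `beta = 1.2` follows the common ROUGE-L convention and is an assumption here):

```python
def rouge_l_f(candidate: str, reference: str, beta: float = 1.2) -> float:
    """LCS-based ROUGE-L F-measure over whitespace tokens."""
    c, r = candidate.split(), reference.split()
    dp = [[0] * (len(r) + 1) for _ in range(len(c) + 1)]  # LCS dynamic program
    for i, ct in enumerate(c):
        for j, rt in enumerate(r):
            dp[i + 1][j + 1] = dp[i][j] + 1 if ct == rt else max(dp[i][j + 1], dp[i + 1][j])
    lcs = dp[len(c)][len(r)]
    if lcs == 0:
        return 0.0
    prec, rec = lcs / len(c), lcs / len(r)
    return (1 + beta ** 2) * prec * rec / (rec + beta ** 2 * prec)
```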
#### Fine-tune the teacher models
```bash
bash scripts/gpt2/sft/sft_xlarge.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### SFT Baselines
```bash
bash scripts/gpt2/sft/sft_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### KD Baselines
```bash
bash scripts/gpt2/kd/kd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### SeqKD Baselines
Generate and process responses with the teacher:
```bash
bash scripts/gpt2/tools/generate_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/tools/process_pseudo_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
Fine-tune the model with SeqKD:
```bash
bash scripts/gpt2/seqkd/seqkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### Student Initialization
The final checkpoints are selected by the **validation loss**.
```bash
bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### ImitKD Baselines
```bash
bash scripts/gpt2/imitkd/imitkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### MiniLLM Baselines
```bash
bash scripts/gpt2/minillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### GKD Baselines
```bash
bash scripts/gpt2/gkd/gkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
### DistiLLM
First, initialize the student models; these initialization checkpoints are selected by the **validation loss**.
```bash
bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
Then run DistiLLM training; the final checkpoints are selected by the **ROUGE-L** scores.
```bash
bash scripts/gpt2/distillm/train_0.1B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_0.3B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_0.7B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
## Run Evaluation
```bash
bash scripts/gpt2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
bash scripts/opt/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
bash scripts/openllama2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
```
## Results
DistiLLM outperforms other KD baselines in terms of both generation performance and training speed for various model families such as GPT-2, OPT, and OpenLLaMA.
<p align="center">
<img width="1394" src="https://github.com/jongwooko/distillm/assets/59277369/19ddac5c-4cd6-4d81-99d8-32723a8e60d8">
</p>
## Checkpoints (OpenLLaMA-3B)
We share the LoRA weights for OpenLLaMA-3B in [google drive](https://drive.google.com/drive/folders/1Yun1aNpn-mz2h-IVH_VdJ1Jhzm0K55Bo?usp=sharing).
## Acknowledgement
Our code is based on the code of the ICLR 2024 paper [MiniLLM: Knowledge Distillation of Large Language Models](https://arxiv.org/pdf/2306.08543.pdf).
## Star History
[](https://star-history.com/#jongwooko/distillm&Date)
## BibTeX
If you find this repo useful for your research, please consider citing our paper:
```
@inproceedings{kodistillm,
title={DistiLLM: Towards Streamlined Distillation for Large Language Models},
author={Ko, Jongwoo and Kim, Sungnyun and Chen, Tianyi and Yun, Se-Young},
booktitle={Forty-first International Conference on Machine Learning},
year={2024}
}
@article{ko2025distillm2,
title={DistiLLM-2: A Contrastive Approach Boosts the Distillation of LLMs},
author={Jongwoo Ko and Tianyi Chen and Sungnyun Kim and Tianyu Ding and Luming Liang and Ilya Zharkov and Se-Young Yun},
year={2025},
journal={arXiv preprint arXiv:2503.07067}
}
```
## Contact
- Jongwoo Ko: jongwoo.ko@kaist.ac.kr
================================================
FILE: arguments.py
================================================
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import deepspeed
import numpy as np
def add_model_args(parser: argparse.ArgumentParser):
"""Model arguments"""
group = parser.add_argument_group('model', 'model configuration')
group.add_argument('--model-path', type=str, help='model path')
group.add_argument("--ckpt-name", type=str)
group.add_argument("--model-type", type=str, default="gpt2")
group.add_argument("--teacher-model-type", type=str, default=None)
group.add_argument("--n-gpu", type=int, default=1)
group.add_argument("--n-nodes", type=int, default=1)
group.add_argument("--teacher-model-path", type=str)
group.add_argument("--teacher-ckpt-name", type=str)
group.add_argument("--teacher-model-fp16", action="store_true")
group.add_argument("--model-parallel", action="store_true")
group.add_argument("--model-parallel-size", type=int, default=None)
group.add_argument("--no-value", action="store_true")
group.add_argument("--dropout-path-rate", type=float, default=None)
group.add_argument("--fp32", action="store_true")
return parser
def add_runtime_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('runtime', 'runtime configurations')
group.add_argument("--type", type=str, default=None)
group.add_argument("--do-train", action="store_true")
group.add_argument("--do-valid", action="store_true")
group.add_argument("--do-eval", action="store_true")
group.add_argument('--base-path', type=str, default=None, help='Path to the project base directory.')
group.add_argument('--load', type=str, default=None,
help='Path to a directory containing a model checkpoint.')
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
group.add_argument("--log-interval", type=int, default=10)
group.add_argument("--mid-log-num", type=int, default=4)
group.add_argument('--save-interval', type=int, default=1000,
help='number of iterations between saves')
group.add_argument("--eval-interval", type=int, default=1000)
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
group.add_argument("--save-additional-suffix", type=str, default="")
group.add_argument("--save-rollout", action="store_true")
group.add_argument("--eb-sample-times", type=int, default=3)
return parser
def add_data_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('data', 'data configurations')
group.add_argument("--data-dir", type=str, default=None)
group.add_argument("--processed-data-dir", type=str, default=None)
group.add_argument("--force-process", action="store_true")
group.add_argument("--force-process-demo", action="store_true")
group.add_argument("--data-process-workers", type=int, default=-1)
group.add_argument("--train-num", type=int, default=-1)
group.add_argument("--train-ratio", type=float, default=1)
group.add_argument("--dev-num", type=int, default=-1)
group.add_argument("--dev-ratio", type=float, default=1)
group.add_argument("--gen-num", type=int, default=-1)
group.add_argument("--data-names", type=str, default=None)
group.add_argument("--prompt-type", type=str, default=None)
group.add_argument("--num-workers", type=int, default=1)
group.add_argument("--max-prompt-length", type=int, default=512)
group.add_argument("--min-prompt-length", type=int, default=128)
group.add_argument("--json-data", action="store_true")
group.add_argument("--bin-data", action="store_true")
group.add_argument("--txt-data", action="store_true")
group.add_argument("--prompt-data-dir", type=str)
group.add_argument("--lm-data-dir", type=str)
group.add_argument("--eval-ppl", action="store_true")
group.add_argument("--eval-rw", action="store_true")
group.add_argument("--eval-gen", action="store_true")
group.add_argument("--only-prompt", action="store_true")
return parser
def add_hp_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("hp", "hyper parameter configurations")
group.add_argument('--batch-size', type=int, default=32,
help='Data Loader batch size')
group.add_argument('--eval-batch-size', type=int, default=32,
help='Data Loader batch size')
group.add_argument('--clip-grad', type=float, default=1.0,
help='gradient clipping')
group.add_argument('--total-iters', type=int, default=None,
help='total number of iterations')
group.add_argument('--train-iters-per-epoch', type=int, default=-1,
help='total number of iterations per epoch')
group.add_argument('--max-length', type=int, default=1024,
help='max length of input')
group.add_argument('--seed', type=int, default=1234,
help='random seed for reproducibility')
group.add_argument("--seed-order", type=int, default=42)
group.add_argument("--seed-data", type=int, default=42)
group.add_argument("--seed-ppo", type=int, default=42)
group.add_argument("--seed-lm", type=int, default=7)
group.add_argument('--epochs', type=int, default=None,
help='total number of epochs to train over all training runs')
group.add_argument('--training-epochs', type=int, default=10000)
group.add_argument("--gradient-accumulation-steps", type=int, default=1)
group.add_argument("--gradient-checkpointing", action="store_true")
group.add_argument("--attn-dtype", default=None)
group.add_argument('--lr', type=float, help='initial learning rate')
group.add_argument("--lr-min", type=float, default=0.0000001)
group.add_argument('--weight-decay', type=float, default=1.0e-2,
help='weight-decay')
group.add_argument('--loss-scale', type=float, default=65536,
help='loss scale')
group.add_argument("--kd-ratio", type=float, default=None)
group.add_argument('--warmup-iters', type=int, default=0,
help='number of warmup iterations. Default 0 (no warmup)')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay LR over,'
' If None defaults to `--train-iters`*`--epochs`')
group.add_argument('--lr-decay-style', type=str, default='noam',
choices=['constant', 'linear', 'cosine', 'exponential', 'noam'],
help='learning rate decay function')
group.add_argument("--scheduler-name", type=str, default="constant_trm")
return parser
def add_ppo_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('ppo', 'ppo configurations')
group.add_argument("--reward-scaling", type=float, default=None)
group.add_argument("--cliprange-reward", type=float, default=1)
group.add_argument("--ppo-epochs", type=int, default=None)
group.add_argument("--num-rollouts", type=int, default=256)
group.add_argument("--num-rollouts-per-device", type=int, default=None)
group.add_argument("--cliprange", type=float, default=0.2)
group.add_argument("--chunk-size", type=int, default=None)
group.add_argument("--gamma", type=float, default=0.95)
return parser
def add_minillm_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('minillm', 'minillm configurations')
group.add_argument("--length-norm", action="store_true")
group.add_argument("--single-step-reg", action="store_true")
group.add_argument("--teacher-mixed-alpha", type=float, default=None)
group.add_argument("--lm-coef", type=float, default=1)
return parser
def add_distillm_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('distillm', 'distillm configurations')
# skew kld
group.add_argument("--skew-alpha", type=float, default=0.1)
# student generation
group.add_argument("--student-gen", action="store_true")
group.add_argument("--gen-top-p", type=float, default=1.0)
group.add_argument("--gen-num-beams", type=int, default=2)
# adaptive threshold
group.add_argument("--mixed-alpha", type=float, default=0.5)
group.add_argument("--loss-eps", type=float, default=0.1)
group.add_argument("--init-threshold", type=float, default=0.0)
# off-policy
group.add_argument("--capacity", type=int, default=1000)
group.add_argument("--replay-ratio", type=str, default="decreasing")
# group.add_argument("--time", action="store_true")
return parser
def add_gen_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('generation', 'generation configurations')
group.add_argument("--top-k", type=int, default=0)
group.add_argument("--top-p", type=float, default=1.0)
group.add_argument("--do-sample", action="store_true")
group.add_argument("--no-repeat-ngram-size", type=int, default=6)
group.add_argument("--repetition-penalty", type=float, default=None)
group.add_argument("--num-beams", type=int, default=1)
group.add_argument("--temperature", type=float, default=1)
return parser
def add_peft_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group('peft', 'parameter-efficient fine-tuning configurations')
group.add_argument("--peft", type=str, default=None)
group.add_argument("--peft-lora-r", type=int, default=16)
group.add_argument("--peft-lora-alpha", type=int, default=64)
group.add_argument("--peft-lora-dropout", type=float, default=0.1)
group.add_argument("--peft-name", type=str, default=None)
group.add_argument("--peft-path", type=str, default=None)
group.add_argument("--teacher-peft-name", type=str, default=None)
group.add_argument("--teacher-peft-path", type=str, default=None)
return parser
def get_args():
parser = argparse.ArgumentParser()
parser = add_model_args(parser)
parser = add_runtime_args(parser)
parser = add_data_args(parser)
parser = add_hp_args(parser)
parser = add_ppo_args(parser)
parser = add_minillm_args(parser)
parser = add_distillm_args(parser)
parser = add_gen_args(parser)
parser = add_peft_args(parser)
parser = deepspeed.add_config_arguments(parser)
args, unknown = parser.parse_known_args()
assert all(["--" not in x for x in unknown]), unknown
args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
args.n_gpu = args.n_gpu * args.n_nodes
if args.type == "eval_main":
ckpt_name = None
if args.ckpt_name is not None:
ckpt_name = args.ckpt_name
if args.peft_name is not None:
ckpt_name = args.peft_name
if ckpt_name is not None:
tmp = ckpt_name.split("/")
if tmp[-1].isdigit():
ckpt_name = "_".join(tmp[:-1]) + "/" + tmp[-1]
else:
ckpt_name = "_".join(tmp)
save_path = os.path.join(
args.save,
f"{args.data_names}-{args.max_length}" + (f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else ""),
ckpt_name,
f"{args.seed}",
)
args.save = save_path
elif args.type == "lm":
save_path = os.path.join(
args.save,
(f"{args.ckpt_name}" + f"-{args.peft_name}" if args.peft_name is not None else ""),
(f"e{args.epochs}-bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}") + \
(f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "") + \
(f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else "") + \
args.save_additional_suffix
)
args.save = save_path
elif args.type == "kd":
save_path = os.path.join(
args.save,
(f"{args.ckpt_name}" + f"-{args.peft_name}" if args.peft_name is not None else "" + \
f"-{args.teacher_ckpt_name}" + f"-{args.teacher_peft_name}" if args.teacher_peft_name is not None else ""),
(f"e{args.epochs}-bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}-kd{args.kd_ratio}") + \
(f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "") + \
(f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else "") + \
args.save_additional_suffix
)
args.save = save_path
elif args.type == "gen":
save_path = os.path.join(
args.save,
(f"{args.ckpt_name}"),
(f"t{args.temperature}-l{args.max_length}"),
)
args.save = save_path
elif args.type == "minillm":
ppo_prefix = f"pe{args.ppo_epochs}" + \
(f"_rs{args.reward_scaling}" if args.ppo_epochs is not None else "") + \
(f"_nr{args.num_rollouts}" if args.num_rollouts is not None else "") + \
(f"_ln" if args.length_norm else "") + \
(f"_sr" if args.single_step_reg else "") + \
(f"_tm{args.teacher_mixed_alpha}" if args.teacher_mixed_alpha is not None else "")
save_path = os.path.join(
args.save,
(f"{args.ckpt_name}" + f"-{args.peft_name}" if args.peft_name is not None else "" + \
f"-{args.teacher_ckpt_name}" + f"-{args.teacher_peft_name}" if args.teacher_peft_name is not None else ""),
(f"bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}-lm{args.lm_coef}-len{args.max_length}" + \
(f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "")) + \
(f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else ""),
ppo_prefix + args.save_additional_suffix
)
args.save = save_path
args.num_rollouts_per_device = args.num_rollouts // args.n_gpu
if args.warmup_iters > 0:
assert args.scheduler_name is not None
return args
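# Illustrative invocation (a sketch, not a script from this repo): the shell scripts
# under scripts/ pass these flags through the DeepSpeed launcher, e.g.
#   deepspeed --master_port ${MASTER_PORT} finetune.py \
#       --type kd --model-path checkpoints/gpt2-base \
#       --teacher-model-path checkpoints/gpt2-xlarge --kd-ratio 0.5 --lr 5e-4
# get_args() then derives the run-specific args.save directory as shown above.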
================================================
FILE: configs/deepspeed/ds_config.json
================================================
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": 1
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 11,
"loss_scale_window": 2000,
"hysteresis": 4
},
"wall_clock_breakdown": false
}
================================================
FILE: configs/deepspeed/ds_config_fp32.json
================================================
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": 1
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": false
},
"wall_clock_breakdown": false
}
================================================
FILE: configs/deepspeed/ds_config_zero2.json
================================================
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none"
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 11,
"loss_scale_window": 5000,
"hysteresis": 4
},
"wall_clock_breakdown": false
}
================================================
FILE: configs/deepspeed/ds_config_zero2_offload.json
================================================
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 11,
"loss_scale_window": 5000,
"hysteresis": 4
},
"wall_clock_breakdown": false
}
================================================
FILE: configs/hostfiles/node_0_1
================================================
node-0 slots=8
node-1 slots=8
================================================
FILE: configs/hostfiles/node_0_1_2_3
================================================
node-0 slots=8
node-1 slots=8
node-2 slots=8
node-3 slots=8
================================================
FILE: configs/hostfiles/node_1_2
================================================
node-1 slots=8
node-2 slots=8
================================================
FILE: configs/hostfiles/node_2_3
================================================
node-2 slots=8
node-3 slots=8
================================================
FILE: data_utils/distributed_indexed.py
================================================
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import struct
import shutil
from itertools import accumulate
import numpy as np
import torch
import torch.distributed as dist
from utils import print_rank, save_rank
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,
7: np.double,
8: np.uint16,
9: np.uint32
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
def __init__(self, path):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, name, rank_number, rank_total, cache = None):
super().__init__()
self._path = path
self._name = name
self._state = 0
if cache is not None:
self._cache = cache
os.makedirs(self._cache, exist_ok=True)
else:
self._cache = None
self._rank_total = rank_total
self._rank_number = rank_number
self._index = None
self._bin_buffer = None
self._bin_buffer_mmap = None
self.max_state, self.history = self._probe_data_path(self._path, self._name, self._rank_total)
self.total_length = self.history[self.max_state-1][1]
self._do_init(self._path, self._name, self._cache, self._state)
def _probe_data_path(self, path, name, rank_total):
print_rank("Probing Dataset")
state = 0
history = {-1:(0, 0)}
for state in range(np.iinfo(np.int32).max):
source_file = path + name + f"_{state}"
if self.exists(source_file):
index = self.Index(index_file_path(source_file))
history[state] = (history[state-1][1], history[state-1][1] + len(index))
else:
break
print_rank(f"Probing end. Max data state {state}, total length {history[state-1][1]}")
return state, history
def __getstate__(self):
return self._path + self._name + "_%d"%(self._state)
def __setstate__(self, state):
self._state = state
self._do_init(self._path, self._name, self._cache, self._state)
def _do_init(self, path, name, cache, state):
if self._bin_buffer_mmap is not None:
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
if self._index is not None:
del self._index
self._state = state
source_file = path + name + f"_{self._state}"
self._index = self.Index(index_file_path(source_file))
self._bin_buffer_mmap = np.memmap(data_file_path(source_file), mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
if self._bin_buffer_mmap is not None:
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
if self._index is not None:
del self._index
def __len__(self):
return self.total_length
def _next_file(self):
self._state += 1
if self._state >= self.max_state:
self._state = 0
# print_rank(f"next_file: {self._state}")
self._do_init(self._path, self._name, self._cache, self._state)
def __relative_idx(self, idx):
res = idx - self.history[self._state][0]
return res
def __slice_item(self, start, stop):
ptr = self._index._pointers[self.__relative_idx(start)]
sizes = self._index._sizes[self.__relative_idx(start):self.__relative_idx(stop)]
offsets = list(accumulate(sizes))
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=sum(sizes), offset=ptr)
return np.split(np_array, offsets[:-1])
def __getitem__(self, idx):
if isinstance(idx, int):
while idx >= self.history[self._state][1] or idx < self.history[self._state][0]:
self._next_file()
ptr, size = self._index[self.__relative_idx(idx)]
return np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
elif isinstance(idx, slice):
raise NotImplementedError()
@property
def sizes(self):
return self._index.sizes
def exists(self, path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
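# Illustrative usage (a sketch; the shard path is hypothetical): shards named
# f"{path}{name}_{state}.bin/.idx" are probed at construction and memory-mapped lazily.
# Note that `path` is string-concatenated with `name`, so keep the trailing slash:
#   ds = DistributedMMapIndexedDataset("processed_data/dolly/", "train",
#                                      get_rank(), get_world_size())
#   token_ids = ds[0]  # np.ndarray of token ids for sample 0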
================================================
FILE: data_utils/indexed_dataset.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate
import numpy as np
import torch
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
else:
return None
else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
def make_builder(out_file, impl, dtype):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file, dtype=dtype)
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,
7: np.double,
8: np.uint16,
9: np.uint32
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx: ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx: ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work, can optimizer later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float32: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state, skip_warmup=True)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
assert idx < len(self._index), "Index {} out of range: {}".format(idx, len(self._index))
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
# @property
# def doc_idx(self):
# return self._index.doc_idx
# def get_doc_idx(self):
# return self._index._doc_idx
# def set_doc_idx(self, doc_idx_):
# self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
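# Illustrative usage (a sketch) of the mmap builder, mirroring make_builder(..., impl='mmap'):
#   builder = MMapIndexedDatasetBuilder("train_0.bin", dtype=np.uint16)
#   builder.add_item(torch.tensor(token_ids))  # append one tokenized sequence
#   builder.end_document()                     # mark a document boundary
#   builder.finalize("train_0.idx")            # write magic, sizes, pointers, doc index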
================================================
FILE: data_utils/lm_datasets.py
================================================
import random
import torch
import os
import json
import pickle
import numpy as np
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset
from torch.distributed import get_rank, get_world_size, barrier
from utils import print_rank
from utils import save_rank
class LMTrainDataset(Dataset):
def __init__(self, args, tokenizer, path, split, num, ratio, rng_sample: random.Random):
self.args = args
self.tokenizer = tokenizer
self.split = split
self.pad_id = self.tokenizer.eos_token_id
self.ratio = ratio
self.max_length = args.max_length
self.max_prompt_length = args.max_prompt_length
self.rng_sample = rng_sample
self.lm_ctx = DistributedMMapIndexedDataset(path, f"{split}", get_rank(), get_world_size())
if os.path.exists(os.path.join(path, f"{split}.jsonl")):
with open(os.path.join(path, f"{split}.jsonl")) as f:
self.raw = [json.loads(line) for line in f.readlines()]
self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
print_rank(len(self.lm_ctx))
if num == -1:
self.num = len(self.lm_ctx)
else:
self.num = num
print_rank(f"Num LM instances: {len(self.lm_ctx)}")
def __len__(self):
return self.num
def __getitem__(self, index):
return self._get_lm(index)
def _get_lm(self, index):
data = self.lm_ctx[index]
input_ids = data.astype(int)
return {
"input_ids": input_ids
}
def _process_lm(self, i, samp, model_data, no_model_data, gen_data):
input_ids = samp["input_ids"]
source_len = 1
prompt = None
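# token id 65535 acts as a sentinel separating the prompt from the response in the
# packed binary samples; it is located, stripped out, and used below to mask the
# prompt positions out of the loss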
if 65535 in input_ids:
source_len = np.where(input_ids==65535)[0][0]
prompt = input_ids[:source_len]
input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0)
input_ids = input_ids[:self.max_length]
input_len = len(input_ids)
model_data["input_ids"][i][:input_len-1] = torch.tensor(input_ids[:-1], dtype=torch.long)
model_data["attention_mask"][i][:input_len-1] = 1.0
if self.args.model_type in ["gpt2"]:
model_data["position_ids"][i][:input_len-1] = torch.arange(0, input_len-1, dtype=torch.long)
no_model_data["label"][i][:input_len-1] = torch.tensor(input_ids[1:], dtype=torch.long)
no_model_data["label"][i][:source_len-1] = -100
no_model_data["loss_mask"][i][:input_len-1] = 1.0
no_model_data["loss_mask"][i][:source_len-1] = 0
if prompt is not None:
gen_data["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
gen_data["attention_mask"][i][-len(prompt):] = 1.0
def move_to_device(self, model_data, no_model_data, gen_data, device):
for k in model_data:
model_data[k] = model_data[k].to(device)
for k in no_model_data:
no_model_data[k] = no_model_data[k].to(device)
for k in gen_data:
gen_data[k] = gen_data[k].to(device)
return model_data, no_model_data, gen_data
def collate(self, samples):
bs = len(samples)
max_length = self.max_length
model_data = {
"input_ids": torch.ones(bs, max_length, dtype=torch.long) * self.pad_id,
"attention_mask": torch.zeros(bs, max_length),
}
if self.args.model_type in ["gpt2"]:
model_data["position_ids"] = torch.zeros(bs, max_length, dtype=torch.long)
no_model_data = {
"label": torch.ones(bs, max_length, dtype=torch.long) * -100,
"loss_mask": torch.zeros(bs, max_length)
}
gen_data = {
"input_ids": torch.ones(bs, self.max_prompt_length, dtype=torch.long) * self.pad_id,
"attention_mask": torch.zeros(bs, self.max_prompt_length, dtype=torch.long),
}
for i, samp in enumerate(samples):
self._process_lm(i, samp, model_data, no_model_data, gen_data)
return model_data, no_model_data, gen_data
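
# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# _process_lm above assumes each packed example separates prompt from response
# with the sentinel token id 65535. The hypothetical helper below replays that
# convention on a toy array, mirroring the slicing logic used above.
def _demo_prompt_packing():
    packed = np.array([11, 12, 13, 65535, 21, 22, 23])  # prompt | sentinel | response
    source_len = np.where(packed == 65535)[0][0]        # -> 3, the prompt length
    prompt = packed[:source_len]                        # [11, 12, 13]
    # drop the sentinel so prompt and response become contiguous input_ids
    input_ids = np.concatenate([packed[:source_len], packed[source_len + 1:]], axis=0)
    assert input_ids.tolist() == [11, 12, 13, 21, 22, 23]
    return prompt, input_ids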
================================================
FILE: data_utils/prompt_datasets.py
================================================
import random
import torch
import os
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset
from torch.distributed import get_rank, get_world_size
from utils import print_rank
from tqdm import tqdm
import json
class PromptDataset(Dataset):
def __init__(self, args, tokenizer, split, data_path=None, num=-1):
super().__init__()
self.tokenizer = tokenizer
self.args = args
self.split = split
self.pad_id = self.tokenizer.eos_token_id
self.max_length = args.max_length
self.min_prompt_length = args.min_prompt_length
self.max_prompt_length = args.max_prompt_length
if args.bin_data:
self.data = DistributedMMapIndexedDataset(data_path, f"{split}", get_rank(), get_world_size())
elif args.json_data:
self.data, self.origin_data = self.load_data_json(data_path)
else:
# txt data
self.data = self.load_data_txt(data_path)
if os.path.exists(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")):
with open(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")) as f:
self.raw = [json.loads(line) for line in f.readlines()]
self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
elif os.path.exists(os.path.join(data_path, f"{split}.jsonl")):
with open(os.path.join(data_path, f"{split}.jsonl")) as f:
self.raw = [json.loads(line) for line in f.readlines()]
self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
        else:
            print_rank("WARNING: No answers exist")
            self.answers = []  # guard: prompt-only datasets ship no references
        self.label_map = {tokenizer.encode(x[0], add_special_tokens=False)[0]: x[0] for x in self.answers} if self.answers else {}
self.num = min(num, len(self.data)) if num > 0 else len(self.data)
print_rank(f"Num instances: {len(self.data)}")
def __len__(self):
return self.num
def load_data_json(self, data_path):
if os.path.exists(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")):
data_path = os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")
else:
data_path = os.path.join(data_path, f"{self.split}.jsonl")
with open(data_path) as f:
lines = f.readlines()
data_origin = [json.loads(line) for line in lines]
data = []
print_rank("Loading Data")
for d in tqdm(data_origin, disable=(get_rank() != 0)):
prompt = d["prompt"].replace("<n>", "\n")
prompt_ids = self.tokenizer.encode(prompt)
output_ids = None
if "output" in d:
if isinstance(d["output"], list):
output_ids = self.tokenizer.encode(d["output"][0])
else:
output_ids = self.tokenizer.encode(d["output"])
            data.append({
                "prompt_ids": prompt_ids,
                # truncate only when a reference output exists; entries without
                # an "output" field keep output_ids as None
                "output_ids": output_ids[:self.max_length - self.max_prompt_length] if output_ids is not None else None
            })
print_rank("Load End")
return data, data_origin
def load_data_txt(self, data_path):
with open(os.path.join(data_path, f"{self.split}.txt")) as f:
lines = f.readlines()
data = []
print_rank("Loading Data")
for line in lines:
line = line.strip()
line = line.replace("<n>", "\n")
prompt = self.tokenizer.encode(line)
data.append(prompt)
print_rank("Load End")
return data
def verbalizer(self):
return self.label_map
def __getitem__(self, index: int):
data = self.data[index]
if self.args.bin_data:
data = data.astype(int)
elif self.args.json_data:
output_ids = data["output_ids"]
data = data["prompt_ids"]
prompt_length = self.max_prompt_length
prompt = data[:prompt_length]
rest = data[prompt_length:]
if self.args.json_data:
if output_ids is not None:
rest = output_ids
return index, prompt, rest
def collate(self, samples):
bs = len(samples)
max_prompt_length = self.max_prompt_length
max_rest_length = max([len(samp[2]) for samp in samples])
model_batch = {
"input_ids": torch.ones(bs, max_prompt_length, dtype=torch.long) * self.pad_id,
"attention_mask": torch.zeros(bs, max_prompt_length, dtype=torch.long),
# "position_ids": torch.zeros(bs, max_prompt_length, dtype=torch.long)
}
no_model_batch = {
"idx": torch.zeros(bs, dtype=torch.long),
"rest_ids": torch.ones(bs, max_rest_length, dtype=torch.long) * self.pad_id
}
for i, (idx, prompt, rest) in enumerate(samples):
# left padding
model_batch["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
model_batch["attention_mask"][i][-len(prompt):] = 1
# model_batch["position_ids"][i][-len(prompt):] = torch.arange(len(prompt))
no_model_batch["idx"][i] = idx
no_model_batch["rest_ids"][i][:len(rest)] = torch.tensor(rest, dtype=torch.long)
return model_batch, no_model_batch
def move_to_device(self, model_batch, no_model_batch, device):
for k in model_batch:
model_batch[k] = model_batch[k].to(device)
for k in no_model_batch:
no_model_batch[k] = no_model_batch[k].to(device)
return model_batch, no_model_batch
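
# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# collate above left-pads prompts so generation starts at a common position.
# A minimal stand-alone rendition of that padding, with a hypothetical pad id:
def _demo_left_padding(pad_id=0, max_prompt_length=6):
    prompt = [5, 6, 7]
    input_ids = torch.full((1, max_prompt_length), pad_id, dtype=torch.long)
    attention_mask = torch.zeros(1, max_prompt_length, dtype=torch.long)
    input_ids[0][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
    attention_mask[0][-len(prompt):] = 1
    # input_ids == [[0, 0, 0, 5, 6, 7]], attention_mask == [[0, 0, 0, 1, 1, 1]]
    return input_ids, attention_mask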
================================================
FILE: distillm/__init__.py
================================================
from .losses import forward_kl, reverse_kl, symmetric_kl, js_distance, tv_distance
from .losses import skewed_forward_kl, skewed_reverse_kl
from .sampler import SampleGenerator
from .buffer import ReplayBuffer
================================================
FILE: distillm/buffer.py
================================================
import random
import torch
import os
import json
import pickle
import numpy as np
from torch.utils.data import Dataset
from torch.distributed import get_rank, get_world_size, barrier
from utils import print_rank
from utils import save_rank
from collections import namedtuple, deque
class ReplayBuffer:
def __init__(self, args):
self.args = args
self.replay_memory = deque(maxlen=args.capacity)
self.bs = args.batch_size
if args.model_type in ["gpt2", "llama"]:
self.data = namedtuple("Generation", \
field_names=["input_ids", "attention_mask", "position_ids", "label", "loss_mask"])
else:
self.data = namedtuple("Generation", \
field_names=["input_ids", "attention_mask", "label", "loss_mask"])
def __len__(self):
return len(self.replay_memory)
def sample(self):
data = random.sample(self.replay_memory, k=self.bs)
input_ids = torch.stack([d.input_ids for d in data], dim=0)
attention_mask = torch.stack([d.attention_mask for d in data], dim=0)
label = torch.stack([d.label for d in data], dim=0)
loss_mask = torch.stack([d.loss_mask for d in data], dim=0)
if self.args.model_type in ["gpt2", "llama"]:
position_ids = torch.stack([d.position_ids for d in data], dim=0)
model_data = {
"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids
}
else:
model_data = {
"input_ids": input_ids, "attention_mask": attention_mask
}
no_model_data = {
"label": label, "loss_mask": loss_mask
}
return model_data, no_model_data
def move_to_device(self, model_data, no_model_data, device):
for k in model_data:
model_data[k] = model_data[k].to(device)
for k in no_model_data:
no_model_data[k] = no_model_data[k].to(device)
return model_data, no_model_data
def move_to_memory(self, model_data, no_model_data):
device = torch.device("cpu")
model_data_cpu, no_model_data_cpu = {}, {}
for k in model_data:
model_data_cpu[k] = model_data[k].to(device)
for k in no_model_data:
no_model_data_cpu[k] = no_model_data[k].to(device)
for idx in range(model_data_cpu["input_ids"].size(0)):
if self.args.model_type in ["gpt2", "llama"]:
e = self.data(model_data_cpu["input_ids"][idx], model_data_cpu["attention_mask"][idx], model_data_cpu["position_ids"][idx],
no_model_data_cpu["label"][idx], no_model_data_cpu["loss_mask"][idx])
else:
e = self.data(model_data_cpu["input_ids"][idx], model_data_cpu["attention_mask"][idx],
no_model_data_cpu["label"][idx], no_model_data_cpu["loss_mask"][idx])
self.replay_memory.append(e)
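
# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# A toy round-trip through the buffer with a hypothetical args namespace:
# generations are stored on CPU via move_to_memory and re-batched by sample().
def _demo_replay_buffer():
    import argparse
    args = argparse.Namespace(capacity=8, batch_size=2, model_type="gpt2")
    buf = ReplayBuffer(args)
    model_data = {
        "input_ids": torch.ones(4, 5, dtype=torch.long),
        "attention_mask": torch.ones(4, 5),
        "position_ids": torch.zeros(4, 5, dtype=torch.long),
    }
    no_model_data = {
        "label": torch.full((4, 5), -100, dtype=torch.long),
        "loss_mask": torch.zeros(4, 5),
    }
    buf.move_to_memory(model_data, no_model_data)   # stores 4 generations
    sampled_model, sampled_no_model = buf.sample()  # draws batch_size of them
    assert sampled_model["input_ids"].shape == (2, 5)
    return sampled_model, sampled_no_model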
================================================
FILE: distillm/losses.py
================================================
import torch
import torch.nn.functional as F
def forward_kl(logits, teacher_logits, no_model_batch):
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
inf_mask = torch.isinf(logits)
student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * student_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
def reverse_kl(logits, teacher_logits, no_model_batch):
student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
teacher_logprobs = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
inf_mask = torch.isinf(teacher_logits) | torch.isinf(logits)
prod_probs = torch.masked_fill(student_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(student_probs * student_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
def symmetric_kl(logits, teacher_logits, no_model_batch, lam=0.9):
for_kl = forward_kl(logits, teacher_logits, no_model_batch)
rev_kl = reverse_kl(logits, teacher_logits, no_model_batch)
distil_loss = (1-lam) * for_kl + lam * rev_kl
return distil_loss
def js_distance(logits, teacher_logits, no_model_batch, lam=0.9):
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
mixed_probs = (1-lam) * teacher_probs + lam * student_probs
teacher_logprobs = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
mixed_logprobs = torch.log(mixed_probs)
mask = (no_model_batch["label"] != -100).int()
inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)
prod_probs = torch.masked_fill(student_probs * mixed_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(student_probs * student_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = lam * -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
prod_probs = torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss += (1-lam) * -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
def tv_distance(logits, teacher_logits, no_model_batch):
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
mask = (no_model_batch["label"] != -100).int()
inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)
prod_probs = 0.5 * torch.masked_fill(torch.abs(teacher_probs - student_probs), inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
mixed_probs = lam * teacher_probs + (1-lam) * student_probs
mixed_logprobs = torch.log(mixed_probs)
mask = (no_model_batch["label"] != -100).int()
inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)
prod_probs = torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
def skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=0.1):
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
mixed_probs = (1-lam) * teacher_probs + lam * student_probs
student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
mixed_logprobs = torch.log(mixed_probs)
mask = (no_model_batch["label"] != -100).int()
inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)
prod_probs = torch.masked_fill(student_probs * mixed_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(student_probs * student_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
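
# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# All losses above share one interface: student logits, teacher logits, and a
# batch dict whose "label" entry marks ignored positions with -100. A toy
# invocation with random tensors; as lam -> 0 the skewed forward KL reduces
# to the plain forward KL computed here.
def _demo_losses():
    bs, seq_len, vocab = 2, 4, 10
    logits = torch.randn(bs, seq_len, vocab)
    teacher_logits = torch.randn(bs, seq_len, vocab)
    label = torch.randint(0, vocab, (bs, seq_len))
    label[:, 0] = -100  # e.g. prompt positions are excluded from the loss
    batch = {"label": label}
    fkl = forward_kl(logits, teacher_logits, batch)
    sfkl = skewed_forward_kl(logits, teacher_logits, batch, lam=0.1)
    return fkl, sfkl  # both scalar tensors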
================================================
FILE: distillm/sampler.py
================================================
import torch
import os
from transformers import GenerationConfig
class SampleGenerator():
def __init__(self, args, tokenizer):
self.args = args
self.tokenizer = tokenizer
self.max_new_token = self.args.max_length - self.args.max_prompt_length
self.pad_id = tokenizer.pad_token_id
self.generation_config = GenerationConfig(
do_sample=args.do_sample,
top_p=args.gen_top_p,
top_k=args.top_k,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
max_length=args.max_length,
min_length=None,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=False
)
def run_sample(self, model, gen_data):
bs = gen_data["input_ids"].size(0)
results = {
"input_ids": torch.ones(bs, self.args.max_length, dtype=torch.long, device=gen_data["input_ids"].device) * self.pad_id,
"attention_mask": torch.zeros(bs, self.args.max_length, dtype=torch.float, device=gen_data["input_ids"].device),
"position_ids": torch.zeros(bs, self.args.max_length, dtype=torch.long, device=gen_data["input_ids"].device),
"no_model_batch": torch.ones(bs, self.args.max_length, dtype=torch.long, device=gen_data["input_ids"].device) * -100,
}
model.eval()
with torch.no_grad():
gen_out = model.generate(
**gen_data,
generation_config=self.generation_config,
max_new_tokens=self.max_new_token,
)
full_ids = gen_out.sequences
input_ids = full_ids[:, :gen_data["input_ids"].size(1)]
response_ids = full_ids[:, gen_data["input_ids"].size(1):]
for i in range(len(input_ids)):
result_id = torch.cat(
(input_ids[i][input_ids[i] != self.pad_id],
response_ids[i][response_ids[i] != self.pad_id]),
)
input_id = input_ids[i][input_ids[i] != self.pad_id]
response_id = response_ids[i][response_ids[i] != self.pad_id]
results["input_ids"][i, :len(result_id)] = result_id
results["position_ids"][i, :len(result_id)] = torch.arange(len(result_id))
results["no_model_batch"][i, len(input_id):len(result_id)] = response_id
results["attention_mask"] = torch.where(results["input_ids"] != self.pad_id, 1, 0)
results["attention_mask"] = results["attention_mask"].float()
results["no_model_batch"] = results["no_model_batch"].long()
return results
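
# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition; not part of the original file).
# run_sample above splits model.generate output into the (left-padded) prompt
# part and the newly generated part, then re-packs both left-aligned. The toy
# below mirrors that split on a fake full_ids tensor, with pad id 0 assumed.
def _demo_split_generation(pad_id=0):
    prompt_len = 3
    full_ids = torch.tensor([[0, 7, 8, 4, 5, 6]])  # pad, prompt, then response
    input_ids = full_ids[:, :prompt_len]           # [[0, 7, 8]]
    response_ids = full_ids[:, prompt_len:]        # [[4, 5, 6]]
    input_id = input_ids[0][input_ids[0] != pad_id]  # strip padding -> [7, 8]
    result_id = torch.cat((input_id, response_ids[0][response_ids[0] != pad_id]))
    return result_id  # tensor([7, 8, 4, 5, 6]), the repacked sequence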
================================================
FILE: evaluate.py
================================================
import time
import os
import torch
import torch.distributed as dist
import deepspeed
import json
from arguments import get_args
from utils import initialize, print_args
from utils import print_rank
from utils import save_rank
from utils import get_tokenizer, get_model
from evaluate_main import evaluate_main, prepare_dataset_main
torch.set_num_threads(4)
def setup_model(args, ds_config, device):
# get the model
model = get_model(args, device)
# get the optimizer and lr_scheduler
optimizer, lr_scheduler = None, None
model, _, _, _ = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=None,
config_params=ds_config
)
# get the memory usage
print_rank("Model mem\n", torch.cuda.memory_summary())
return model
def main():
torch.backends.cudnn.enabled = False
args = get_args()
initialize(args)
if dist.get_rank() == 0:
print_args(args)
with open(os.path.join(args.save, "args.json"), "w") as f:
json.dump(vars(args), f)
device = torch.cuda.current_device()
cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))
print("OK")
with open(args.deepspeed_config, "r") as f:
ds_config = json.load(f)
ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
ds_config["gradient_clipping"] = args.clip_grad
ds_config["steps_per_print"] = args.gradient_accumulation_steps
if not args.do_train:
ds_config["zero_optimization"]["stage"] = 0
args.fp32 = not ds_config["fp16"]["enabled"]
args.deepspeed_config = None
# get the tokenizer
tokenizer = get_tokenizer(args)
if args.type == "eval_main":
dataset = prepare_dataset_main(
args,
tokenizer,
)
else:
raise NotImplementedError
model = setup_model(args, ds_config, device)
if args.type == "eval_main":
evaluate_main(args, tokenizer, model, dataset["test"], "test", 0, device)
else:
raise NotImplementedError
if __name__ == "__main__":
main()
================================================
FILE: evaluate_main.py
================================================
from data_utils.prompt_datasets import PromptDataset
from transformers import GenerationConfig
import os
import nltk
nltk.download("punkt")
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import json
from utils import print_rank, save_rank, all_gather
from rouge_metric import compute_metrics
torch.set_num_threads(4)
def prepare_dataset_main(args, tokenizer):
data = {}
data["test"] = PromptDataset(args, tokenizer, "valid", args.data_dir, args.dev_num)
return data
def run_model(args, tokenizer, model, dataset: PromptDataset, epoch, device):
collate_fn = dataset.collate
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
dp_group = None
sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
dataloader = DataLoader(
dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
model.eval()
all_query_ids = []
all_response_ids = []
all_lm_losses = []
    generation_config = GenerationConfig(
do_sample=args.do_sample,
top_p=args.top_p,
top_k=args.top_k,
temperature=args.temperature,
no_repeat_ngram_size=args.no_repeat_ngram_size,
repetition_penalty=args.repetition_penalty,
max_length=args.max_length,
min_length=None,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
return_dict_in_generate=True,
output_scores=True
)
with torch.no_grad():
for it, (model_batch, no_model_batch) in enumerate(tqdm(dataloader, desc=f"Evaluating {args.data_names} ", disable=(dist.get_rank() != 0))):
if it == 0:
print_rank("############### Example ###############")
print_rank(tokenizer.decode(model_batch["input_ids"][0], skip_special_tokens=True))
print_rank("############### End ###############")
dataset.move_to_device(model_batch, no_model_batch, device)
all_ids = torch.cat([model_batch["input_ids"], no_model_batch["rest_ids"]], dim=-1)
input_ids = all_ids[:, :-1]
attention_mask = (input_ids != tokenizer.pad_token_id).long()
label_ids = all_ids[:, 1:]
label_ids = torch.masked_fill(label_ids, label_ids==tokenizer.pad_token_id, -100)
label_ids[:, :model_batch["input_ids"].size(1)-1] = -100
if args.model_type in ["gpt2"]:
position_ids = (torch.cumsum(attention_mask, dim=-1) - 1) * attention_mask
out = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, return_dict=True)
else:
out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
logits = out.logits
loss_mask = (label_ids != -100).float()
loss_func = nn.CrossEntropyLoss(reduction="none")
lm_loss = loss_func(logits.view(-1, logits.size(-1)), label_ids.view(-1)).view(label_ids.size())
lm_loss = torch.sum(lm_loss * loss_mask, -1) / torch.sum(loss_mask, -1)
all_lm_losses.append(lm_loss)
query_ids = model_batch["input_ids"]
max_new_tokens = args.max_length - query_ids.size(1)
gen_out = model.generate(
**model_batch,
generation_config=generation_config,
max_new_tokens=max_new_tokens
)
full_ids = gen_out.sequences
response_ids = full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
query_ids = F.pad(query_ids, (args.max_prompt_length-query_ids.size(1), 0, 0, 0), value=tokenizer.pad_token_id)
response_ids = F.pad(response_ids, (0, args.max_length-args.max_prompt_length-response_ids.size(1), 0, 0), value=tokenizer.pad_token_id)
all_query_ids.append(query_ids)
all_response_ids.append(response_ids)
all_lm_losses = torch.cat(all_lm_losses)
mean_lm_loss = all_lm_losses.mean()
dist.all_reduce(mean_lm_loss, dist.ReduceOp.SUM, group=dp_group)
mean_lm_loss = mean_lm_loss.item() / dp_world_size
all_query_ids = torch.cat(all_query_ids)
all_query_ids = all_gather(all_query_ids, dim=1, group=dp_group, world_size=dp_world_size, op="stack")
all_query_ids = all_query_ids.view(-1, all_query_ids.size(-1))
all_query_ids = all_query_ids[:len(dataset)]
all_response_ids = torch.cat(all_response_ids)
all_response_ids = all_gather(all_response_ids, dim=1, group=dp_group, world_size=dp_world_size, op="stack")
all_response_ids = all_response_ids.view(-1, all_response_ids.size(-1))
all_response_ids = all_response_ids[:len(dataset)]
return (
mean_lm_loss,
all_query_ids,
all_response_ids)
def evaluate_main(args, tokenizer, model, dataset: PromptDataset, split, epoch, device):
lm_loss, query_ids, response_ids = run_model(args, tokenizer, model, dataset, epoch, device)
query_strs = tokenizer.batch_decode(query_ids, skip_special_tokens=True)
response_strs = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
with open(os.path.join(args.save, "preds.txt"), "w") as f:
for q, r in zip(query_strs, response_strs):
f.write(q.replace("\n", "<n>") + "\t\t" + r.replace("\n", "<n>") + "\n")
all_preds = [[]]
for q, r in zip(query_strs, response_strs):
all_preds[0].append((q, q + r))
torch.save(all_preds, os.path.join(args.save, "preds.pt"))
all_responses = []
with open(os.path.join(args.save, "answers.jsonl"), "w") as f:
for p in all_preds[0]:
q, r = p
r = r[len(q):]
idx = r.find("<|endoftext|>")
if idx >= 0:
r = r[:idx]
f.write(json.dumps({
"text": r.replace("<n>", "\n").strip()
}) + "\n")
all_responses.append(r.replace("<n>", "\n").strip())
gen_res = compute_metrics(all_responses, dataset.answers)
mean_gen_length = np.mean([len(tokenizer.encode(s)) for s in response_strs])
    log_str = f"{split} | name: {args.data_names} | {gen_res} | lm_loss {round(lm_loss, 4)} | avg. gen length: {mean_gen_length}"
print_rank(log_str)
save_rank(log_str, os.path.join(args.save, "log.txt"))
================================================
FILE: finetune.py
================================================
import time
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim import AdamW
import deepspeed
import random
import json
from tqdm import tqdm
import math
import datetime
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
GenerationConfig)
from transformers import get_constant_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingLR
from arguments import get_args
from data_utils.lm_datasets import LMTrainDataset
from utils import get_optimizer_params, get_optimizer_params_peft, print_args, initialize
from utils import print_rank, get_rank
from utils import save_rank
from utils import all_gather
from utils import load_parallel, save_parallel
from utils import get_tokenizer, get_model
from distillm import forward_kl, reverse_kl, js_distance, tv_distance
from distillm import skewed_forward_kl, skewed_reverse_kl
from distillm import SampleGenerator, ReplayBuffer
from rouge_metric import compute_metrics
from peft import PeftModel
torch.set_num_threads(4)
def get_teacher_model(args, device):
config = AutoConfig.from_pretrained(args.teacher_model_path)
if args.model_parallel:
raise NotImplementedError
else:
config.is_model_parallel = False
    # prefer loading the teacher directly in fp16; fall back to fp32 and cast down
    try:
        model = AutoModelForCausalLM.from_pretrained(args.teacher_model_path, config=config, device_map={"": device}, torch_dtype=torch.float16)
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(args.teacher_model_path, config=config, device_map={"": device}, torch_dtype=torch.float32)
        model = model.half()
if args.peft is not None and args.teacher_peft_path is not None:
if args.peft == "lora":
model = PeftModel.from_pretrained(model, args.teacher_peft_path)
model = model.merge_and_unload()
else:
raise NotImplementedError
else:
if dist.get_rank() == 0:
print(' > number of parameters: {}'.format(
sum([p.nelement() for p in model.parameters()])), flush=True)
model.eval()
return model
def get_optimizer(args, model):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, DDP):
model = model.module
if args.peft is not None:
param_groups = get_optimizer_params_peft(args, model)
else:
param_groups = get_optimizer_params(args, model)
# Use AdamW.
optimizer = AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
print_rank(f'Optimizer = {optimizer.__class__.__name__}')
return optimizer
def get_learning_rate_scheduler(args, optimizer):
if args.total_iters is None:
args.total_iters = args.train_iters_per_epoch * args.epochs
if args.lr_decay_style == "constant":
lr_scheduler = get_constant_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_iters)
elif args.lr_decay_style == "cosine":
lr_scheduler = CosineAnnealingLR(
optimizer,
T_max=args.total_iters,
eta_min=args.lr_min)
elif args.lr_decay_style == "noam":
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_iters,
num_training_steps=args.total_iters,
power=0.5)
else:
raise ValueError(f"lr_scheduler of type {args.lr_decay_style} is not supported yet.")
return lr_scheduler
def setup_model_and_optimizer(args, ds_config, device, set_optim=True):
# get the model
model = get_model(args, device)
# get the optimizer and lr_scheduler
if set_optim:
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
else:
optimizer, lr_scheduler = None, None
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=None,
config_params=ds_config
)
# get the memory usage
print_rank("Model mem\n", torch.cuda.memory_summary())
return model, optimizer, lr_scheduler
def prepare_dataset(args, tokenizer):
data = {}
rng_sample = random.Random(args.seed)
if args.do_train:
data["train"] = LMTrainDataset(args, tokenizer, args.data_dir, "train", args.train_num, args.train_ratio, rng_sample)
print_rank("train num", len(data["train"]))
data["dev"] = LMTrainDataset(args, tokenizer, args.data_dir, "valid", args.dev_num, args.dev_ratio, rng_sample)
elif args.do_eval:
data["test"] = LMTrainDataset(args, tokenizer, args.data_dir, "valid", args.dev_num, args.dev_ratio, rng_sample)
else:
        raise ValueError("At least one of do_train and do_eval must be set")
# pre-trained dataset
if args.do_train and args.lm_data_dir is not None:
data["pt_train"] = LMTrainDataset(args, tokenizer, args.lm_data_dir, "train", args.train_num, args.train_ratio, rng_sample)
print_rank("train num", len(data["pt_train"]))
return data
def pt_loss(args, model, model_batch, no_model_batch):
outputs = model(**model_batch, return_dict=True, use_cache=False)
logits = outputs.logits
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
lm_loss = loss_fn(logits.view(-1, logits.size(-1)), no_model_batch["label"].view(-1))
return lm_loss
def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
with torch.no_grad():
teacher_model.eval()
teacher_outputs = teacher_model(**model_batch, use_cache=False)
teacher_logits = teacher_outputs.logits
if args.model_parallel:
raise NotImplementedError
else:
if "sfkl" in args.type:
distil_loss = skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=args.skew_alpha)
elif "srkl" in args.type:
distil_loss = skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=args.skew_alpha)
elif "jsd" in args.type:
distil_loss = js_distance(logits, teacher_logits, no_model_batch)
elif "tvd" in args.type:
distil_loss = tv_distance(logits, teacher_logits, no_model_batch)
elif "fkl" in args.type or args.type == "kd":
distil_loss = forward_kl(logits, teacher_logits, no_model_batch)
elif "rkl" in args.type:
distil_loss = reverse_kl(logits, teacher_logits, no_model_batch)
else:
raise NotImplementedError
return distil_loss
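
# ---------------------------------------------------------------------------
# Illustrative note (added for exposition; not part of the original file).
# finetune() below mixes the distillation loss returned above with the plain
# LM loss as loss = (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss.
# A toy check of that mixing with hypothetical scalar losses:
def _demo_kd_mixing(kd_ratio=0.5):
    lm_loss, distil_loss = torch.tensor(2.0), torch.tensor(1.0)
    loss = (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss
    return loss  # -> tensor(1.5) for kd_ratio = 0.5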
def get_teacher_lm_loss(args, tokenizer, model, teacher_model, model_batch):
with torch.no_grad():
t_gen_out = teacher_model.generate(
**model_batch,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
max_length=args.max_length,
top_k=0,
top_p=1,
temperature=1.0,
do_sample=True,
return_dict_in_generate=True,
output_scores=False)
full_ids = t_gen_out.sequences
input_ids = full_ids[:, :-1]
mask = (input_ids != tokenizer.pad_token_id).long()
labels = full_ids[:, 1:]
labels = torch.masked_fill(labels, mask==0, -100)
labels[:, :model_batch["input_ids"].size(1)-1] = -100
loss_mask = (labels != -100).float()
new_batch = {
"input_ids": input_ids,
"attention_mask": mask,
}
if args.model_type in ["gpt2"]:
position_ids = torch.cumsum(mask, dim=-1) - 1
position_ids = torch.masked_fill(position_ids, mask==0, 0)
new_batch["position_ids"] = position_ids
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
outputs = model(**new_batch, return_dict=True, use_cache=False)
logits = outputs.logits
lm_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
return lm_loss
def finetune(args, tokenizer: AutoTokenizer, model: deepspeed.DeepSpeedEngine, optimizer: AdamW, lr_scheduler, dataset, device, teacher_model=None):
print_rank("Start Fine-tuning")
# print_inspect(model, '*')
if args.model_parallel:
raise NotImplementedError
else:
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
dp_group = None
loss_func = nn.CrossEntropyLoss()
sampler = DistributedSampler(dataset["train"], shuffle=True, drop_last=True, rank=dp_rank, num_replicas=dp_world_size)
train_dataloader = DataLoader(
dataset['train'], sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset["train"].collate)
if "pt_train" in dataset:
pt_sampler = DistributedSampler(dataset["pt_train"], shuffle=True, drop_last=True, rank=dp_rank, num_replicas=dp_world_size)
pt_train_dataloader = DataLoader(
dataset['pt_train'], sampler=pt_sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset["pt_train"].collate)
pt_train_iter = iter(pt_train_dataloader)
student_generator = SampleGenerator(args, tokenizer)
step, global_step = 1, 1
total_loss, total_distil_loss, total_time = 0.0, 0.0, 0.0
adaptive_threshold = args.init_threshold if "adaptive" in args.type else None
prev_avg_loss = evaluate(args, tokenizer, model, dataset["dev"], "dev", 0, device, adaptive_threshold)
replay_buffer = ReplayBuffer(args)
for epoch in range(args.epochs):
sampler.set_epoch(epoch)
model.train()
for it, (model_batch, no_model_batch, gen_data) in enumerate(train_dataloader):
dataset["train"].move_to_device(model_batch, no_model_batch, gen_data, device)
if args.lm_data_dir is not None:
                try:
                    pt_model_batch, pt_no_model_batch, pt_gen_data = next(pt_train_iter)
                except StopIteration:
                    # restart the pre-training data iterator once exhausted
                    pt_train_iter = iter(pt_train_dataloader)
                    pt_model_batch, pt_no_model_batch, pt_gen_data = next(pt_train_iter)
dataset["pt_train"].move_to_device(pt_model_batch, pt_no_model_batch, pt_gen_data, device)
torch.cuda.synchronize()
st_time = time.time()
            # sampling ratio: only defined for the adaptive variants, since
            # adaptive_threshold is None otherwise
            samp_threshold = None
            if "adaptive" in args.type:
                if args.replay_ratio == "constant":
                    samp_threshold = adaptive_threshold * 0.5
                elif args.replay_ratio == "increasing":
                    samp_threshold = adaptive_threshold * global_step / args.total_iters
                else:
                    samp_threshold = adaptive_threshold * (1 - global_step / args.total_iters)
# data generation
if args.student_gen:
r = np.random.uniform(0, 1)
if "mixed" in args.type and r < args.mixed_alpha:
model_batch = student_generator.run_sample(model, gen_data)
no_model_batch["label"] = model_batch.pop("no_model_batch")
replay_buffer.move_to_memory(model_batch, no_model_batch)
model_batch, no_model_batch = replay_buffer.sample()
model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)
elif "adaptive" in args.type and (r < samp_threshold or (r < adaptive_threshold and len(replay_buffer) < args.capacity)):
model_batch = student_generator.run_sample(model, gen_data)
no_model_batch["label"] = model_batch.pop("no_model_batch")
if args.model_type in ["opt"]:
model_batch.pop('position_ids')
replay_buffer.move_to_memory(model_batch, no_model_batch)
elif "adaptive" in args.type and r < adaptive_threshold:
model_batch, no_model_batch = replay_buffer.sample()
model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)
model.train()
outputs = model(**model_batch, use_cache=False)
logits = outputs.logits
if args.model_parallel:
raise NotImplementedError
else:
lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))
if teacher_model is not None:
distil_loss = get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits)
loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
else:
loss = lm_loss
if args.lm_data_dir is not None:
assert args.lm_coef is not None
loss += args.lm_coef * pt_loss(args, model, pt_model_batch, pt_no_model_batch)
model.backward(loss)
model.step()
dist.all_reduce(loss, dist.ReduceOp.SUM, group=dp_group)
global_loss = loss.item() / dp_world_size
global_distil_loss = 0
if teacher_model is not None:
dist.all_reduce(distil_loss, dist.ReduceOp.SUM, group=dp_group)
global_distil_loss = distil_loss.item() / dp_world_size
total_distil_loss += global_distil_loss
torch.cuda.synchronize()
elapsed_time = time.time() - st_time
total_loss += global_loss
total_time += elapsed_time
# Logging
def get_log(log_loss, log_distil_loss, log_time):
return "train | epoch {:3d} | Iter: {:6d}/{:6d} | global iter: {:6d}/{:6d} | loss: {:.4f} | ds_loss: {:.4f} | lr: {:.4e} | scale: {:10.4f} | micro time: {:.3f} | step time: {:.3f}".format(
epoch,
step,
args.total_iters * args.gradient_accumulation_steps,
global_step,
args.total_iters,
log_loss,
log_distil_loss,
lr_scheduler.get_last_lr()[0],
optimizer.cur_scale if hasattr(optimizer, "cur_scale") else 0,
elapsed_time,
log_time,
)
if args.mid_log_num > 0:
mid_log_step = args.gradient_accumulation_steps // args.mid_log_num
mid_log_step = 1 if mid_log_step == 0 else mid_log_step
if step % mid_log_step == 0:
print_rank(get_log(global_loss, global_distil_loss, 0))
if global_step % args.log_interval == 0 and step % args.gradient_accumulation_steps == 0:
log_str = get_log(
total_loss / (args.log_interval * args.gradient_accumulation_steps),
total_distil_loss / (args.log_interval * args.gradient_accumulation_steps),
total_time / (args.log_interval))
print_rank("*" * 100)
print_rank(log_str)
print_rank(args.save)
print_rank("*" * 100)
save_rank(log_str, os.path.join(args.save, "log.txt"))
total_loss, total_distil_loss, total_time = 0.0, 0.0, 0.0
# Checkpointing
if args.save and args.save_interval and global_step % args.save_interval == 0 and step % args.gradient_accumulation_steps == 0:
save_dir_path = os.path.join(args.save, str(global_step))
if args.model_parallel:
raise NotImplementedError
else:
if dist.get_rank() == 0:
os.makedirs(save_dir_path, exist_ok=True)
print_rank(f"Model save to {save_dir_path}")
tokenizer.save_pretrained(save_dir_path)
model.module.save_pretrained(save_dir_path, safe_serialization=False)
dist.barrier()
# Evaluation
if args.eval_interval and global_step % args.eval_interval == 0 and step % args.gradient_accumulation_steps == 0:
curr_avg_loss = evaluate(args, tokenizer, model, dataset["dev"], "dev", epoch, device, adaptive_threshold)
if "adaptive" in args.type:
if curr_avg_loss >= prev_avg_loss + args.loss_eps:
adaptive_threshold += 0.1
adaptive_threshold = min(adaptive_threshold, 1.0)
prev_avg_loss = curr_avg_loss
model.train()
step += 1
if step % args.gradient_accumulation_steps == 0:
global_step += 1
if global_step > args.total_iters:
break
return model
def evaluate(args, tokenizer, model, dataset: LMTrainDataset, split, epoch, device, adaptive_threshold=None):
collate_fn = dataset.collate
if args.model_parallel:
raise NotImplementedError
else:
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
dp_group = None
loss_func = nn.CrossEntropyLoss()
print_rank("dp size", dp_world_size)
generation_config = GenerationConfig(
do_sample=args.do_sample,
top_p=args.top_p,
top_k=args.top_k,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
max_length=args.max_length,
min_length=None,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=False
)
sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
dataloader = DataLoader(
dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
model.eval()
all_loss = 0.0
step = 0
all_response_ids = []
with torch.no_grad():
for it, (model_batch, no_model_batch, gen_data) in enumerate(tqdm(dataloader, desc="Evaluating", disable=(dist.get_rank() != 0))):
print_rank(f"{it}/{len(dataloader)}")
dataset.move_to_device(model_batch, no_model_batch, gen_data, device)
logits = model(**model_batch).logits
if args.model_parallel:
raise NotImplementedError
else:
loss = loss_func(logits.view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))
max_new_tokens = args.max_length - gen_data["input_ids"].size(1)
if args.eval_gen:
gen_out = model.generate(
**gen_data,
generation_config=generation_config,
max_new_tokens=max_new_tokens)
full_ids = gen_out.sequences
full_ids = F.pad(
full_ids,
(0, args.max_length - full_ids.shape[1]),
value=tokenizer.pad_token_id,
)
response_ids = full_ids[:, gen_data["input_ids"].size(1):]
all_response_ids.append(response_ids)
dist.all_reduce(loss, dist.ReduceOp.SUM, group=dp_group)
loss = loss / dp_world_size
all_loss += loss.item()
step += 1
if args.eval_gen:
all_response_ids = torch.cat(all_response_ids, dim=0)
all_response_ids = all_gather(all_response_ids, dim=1, world_size=dp_world_size, group=dp_group, op="stack")
all_response_ids = all_response_ids.view(-1, all_response_ids.size(-1))
responses = tokenizer.batch_decode(all_response_ids, skip_special_tokens=True)
if get_rank() == 0:
if args.eval_gen:
references = dataset.answers
responses = responses[:len(references)]
res = compute_metrics(responses, references)
eval_dir = os.path.join(args.save, "eval", str(epoch))
print_rank(eval_dir)
os.makedirs(eval_dir, exist_ok=True)
with open(os.path.join(eval_dir, "answers.jsonl"), "w") as f:
for resp in responses:
f.write(json.dumps({"text": resp}) + "\n")
else:
res = {}
avg_loss = all_loss / step
if "adaptive" in args.type:
log_str = f"{split} | avg_loss: {avg_loss} | {res} | threshold: {adaptive_threshold}"
else:
log_str = f"{split} | avg_loss: {avg_loss} | {res}"
print_rank(log_str)
save_rank(log_str, os.path.join(args.save, "log.txt"))
return all_loss / step
def main():
torch.backends.cudnn.enabled = False
args = get_args()
initialize(args)
if dist.get_rank() == 0:
print_args(args)
with open(os.path.join(args.save, "args.json"), "w") as f:
json.dump(vars(args), f)
device = torch.cuda.current_device()
cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))
with open(args.deepspeed_config, "r") as f:
ds_config = json.load(f)
ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
ds_config["gradient_clipping"] = args.clip_grad
ds_config["steps_per_print"] = 10000000
if not args.do_train:
ds_config["zero_optimization"]["stage"] = 0
args.fp32 = not ds_config["fp16"]["enabled"]
args.deepspeed_config = None
# get the tokenizer
tokenizer = get_tokenizer(args)
dataset = prepare_dataset(
args,
tokenizer,
)
dp_world_size = dist.get_world_size()
if args.do_train:
args.train_iters_per_epoch = int(len(dataset["train"]) / (args.batch_size * dp_world_size * args.gradient_accumulation_steps))
print_rank("Train iters per epoch", args.train_iters_per_epoch)
if args.total_iters is None:
args.total_iters = args.train_iters_per_epoch * args.epochs
if args.epochs is None:
args.epochs = math.ceil(args.total_iters / args.train_iters_per_epoch)
print_rank("total_iters", args.total_iters)
if args.save_interval == -1:
args.save_interval = args.train_iters_per_epoch
if args.eval_interval == -1:
args.eval_interval = args.train_iters_per_epoch
model, optimizer, lr_scheduler = setup_model_and_optimizer(args, ds_config, device, set_optim=args.do_train)
if args.teacher_model_type is None:
args.teacher_model_type = args.model_type
if args.teacher_model_path is not None:
teacher_model = get_teacher_model(args, device)
else:
teacher_model = None
if args.do_train:
model = finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset, device, teacher_model=teacher_model)
if args.do_eval:
evaluate(args, tokenizer, model, dataset["test"], "test", 0, device)
if __name__ == "__main__":
main()
================================================
FILE: generate.py
================================================
import time
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import deepspeed
import numpy as np
import json
from tqdm import tqdm
from transformers import mpu
from arguments import get_args
from data_utils.prompt_datasets import PromptDataset
from utils import print_args, initialize
from utils import print_rank, get_rank
from utils import save_rank
from utils import all_gather
from utils import get_tokenizer, get_model
torch.set_num_threads(4)
def setup_model(args, ds_config, device):
# get the model
model = get_model(args, device)
# get the optimizer and lr_scheduler
optimizer, lr_scheduler = None, None
model, _, _, _ = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
mpu=mpu if args.model_parallel else None,
config_params=ds_config
)
# get the memory usage
print_rank("Model mem\n", torch.cuda.memory_summary())
return model
def prepare_dataset(args, tokenizer):
    data = PromptDataset(args, tokenizer, "train", data_path=args.data_dir, num=args.gen_num)
print_rank("gen num", len(data))
return data
def generate(args, tokenizer, model, dataset, device):
collate_fn = dataset.collate
if args.model_parallel:
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
dp_group = mpu.get_data_parallel_group()
else:
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
dp_group = None
sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
dataloader = DataLoader(
dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
model.eval()
all_gen_ids = []
all_idxs = []
max_new_tokens = args.max_length - args.max_prompt_length
with torch.no_grad():
for it, (model_batch, no_model_batch) in enumerate(tqdm(dataloader, desc="Generating", disable=(dist.get_rank() != 0))):
dataset.move_to_device(model_batch, no_model_batch, device)
t_gen_out = model.generate(
**model_batch,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=max_new_tokens,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
do_sample=True,
return_dict_in_generate=True,
output_scores=False)
full_ids = t_gen_out.sequences
gen_ids = full_ids[:, model_batch["input_ids"].size(1):]
buffer = torch.ones(gen_ids.size(0), max_new_tokens, dtype=torch.long, device=gen_ids.device) * tokenizer.pad_token_id
buffer[:, :gen_ids.size(1)] = gen_ids
all_gen_ids.append(buffer)
all_idxs.append(no_model_batch["idx"])
all_idxs = all_gather(torch.cat(all_idxs, dim=0), dim=0, world_size=dp_world_size, group=dp_group).cpu().tolist()
all_gen_ids = all_gather(torch.cat(all_gen_ids, dim=0), dim=0, world_size=dp_world_size, group=dp_group).cpu().tolist()
if get_rank() == 0:
all_gen_strs = tokenizer.batch_decode(all_gen_ids, skip_special_tokens=True)
mean_lens = np.mean([len(tokenizer.encode(x)) for x in all_gen_strs[:100]])
log_str = f"gen | avg. lens: {mean_lens}"
print_rank(log_str)
save_rank(log_str, os.path.join(args.save, "log.txt"))
assert len(all_idxs) == len(all_gen_strs)
for idx, g in zip(all_idxs, all_gen_strs):
dataset.origin_data[idx]["gen_answer"] = g
with open(os.path.join(args.save, "raw.jsonl"), "w") as f:
for d in dataset.origin_data:
if "gen_answer" in d:
f.write(json.dumps(d) + "\n")
dist.barrier()
def main():
torch.backends.cudnn.enabled = False
args = get_args()
initialize(args)
if dist.get_rank() == 0:
print_args(args)
with open(os.path.join(args.save, "args.json"), "w") as f:
json.dump(vars(args), f)
device = torch.cuda.current_device()
cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))
with open(args.deepspeed_config, "r") as f:
ds_config = json.load(f)
ds_config["steps_per_print"] = args.gradient_accumulation_steps
ds_config["zero_optimization"]["stage"] = 0
args.fp32 = not ds_config["fp16"]["enabled"]
args.deepspeed_config = None
# get the tokenizer
tokenizer = get_tokenizer(args)
dataset = prepare_dataset(
args,
tokenizer,
)
model = setup_model(args, ds_config, device)
generate(args, tokenizer, model, dataset, device)
if __name__ == "__main__":
main()
================================================
FILE: install.sh
================================================
export NCCL_DEBUG=""
# conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
# pip install transformers==4.42.4
pip install vllm==0.5.0
pip install deepspeed
pip install nltk
pip install numerize
pip install rouge-score
pip install torchtyping
pip install rich
pip install accelerate
pip install datasets
pip install sentencepiece
pip install protobuf
pip install peft
================================================
FILE: minillm/__init__.py
================================================
from deepspeed import DeepSpeedConfig
from typing import Optional
# from trlx.utils.loading import get_orchestrator, get_pipeline, get_trainer
from .sampler import PPOSampler
from .pipelines import PPOPipeline, LMPipeline
from .trainer import PPOTrainer
from .reward import Reward
def train(
args,
tokenizer,
reward_fn = None,
teacher_model=None,
prompt_data: Optional[str] = None,
eval_prompt_data: Optional[str] = None,
lm_data: Optional[str] = None,
eval_lm_data: Optional[str] = None,
ds_config: Optional[DeepSpeedConfig] = None,
):
trainer = PPOTrainer(
args=args,
tokenizer=tokenizer,
reward_fn=reward_fn,
ds_config=ds_config,
)
trainer.set_teacher_model(teacher_model)
ppo_pipeline = PPOPipeline(
args, tokenizer, "train", prompt_data, num=args.train_num
)
sampler = PPOSampler(
args, trainer, ppo_pipeline, chunk_size=args.chunk_size
)
sampler.run_sample(args.num_rollouts_per_device)
eval_ppo_pipeline = PPOPipeline(
args, trainer.tokenizer, "valid", eval_prompt_data, fix_prompts=True, num=args.dev_num
)
trainer.add_eval_pipeline(eval_ppo_pipeline)
lm_pipeline = LMPipeline(
args, trainer.tokenizer, "train", lm_data, num=args.train_num) if lm_data is not None else None
eval_lm_pipeline = LMPipeline(
args, trainer.tokenizer, "valid", eval_lm_data, num=args.dev_num) if eval_lm_data is not None else None
trainer.add_lm_pipeline(lm_pipeline, eval_lm_pipeline)
trainer.train()
return trainer
================================================
FILE: minillm/data_types.py
================================================
from dataclasses import dataclass
from typing import Iterable
from torchtyping import TensorType
@dataclass
class PromptElement:
"""
Dataclass for a single prompt, containing its string and tokenized form.
:param text: The prompt text.
:type text: str
:param tokens: The prompt tokens. Should be a long tensor
:type tokens: torch.Tensor
"""
text: str
tokens: TensorType["num_tokens"]
@dataclass
class PromptBatch:
"""
Batched PromptElement
:param text: An iterable of prompt texts.
:type text: Iterable[str]
:param tokens: A long tensor batch of prompt tokens.
:type tokens: torch.Tensor
"""
text: Iterable[str]
tokens: TensorType["batch_size", "num_tokens"]
@dataclass
class PPORLElement:
"""
:param query_tensor: The query tensor i.e. the prompt tokens.
Should be a long tensor.
:type query_tensor: torch.Tensor
:param response_tensor: The response tensor i.e. the output tokens.
Should be a long tensor.
:type response_tensor: torch.Tensor
    :param logprobs: The per-token log probabilities of the response under
        the policy network (i.e. the autoregressive model).
        Should be a float tensor of the same size as the response.
    :type logprobs: torch.Tensor
:param rewards: The rewards for each token outputted in response.
Should be a float tensor of same size as tokens.
:type rewards: torch.Tensor
"""
query_tensor: TensorType["query_size"]
response_tensor: TensorType["response_size"]
lens: int
s_lens: int
mask: TensorType["response_size"]
logprobs: TensorType["response_size"]
rewards: TensorType["response_size"]
rev_kl: TensorType["response_size"]
w: TensorType["response_size"]
inf_mask: TensorType["response_size", "vocab_size"]
t_rewards: TensorType["response_size"]
ent_rewards: TensorType["response_size"]
@dataclass
class PPORLBatch:
"""
A batched version of the PPORLElement. See PPORLElement for more details on individual fields.
:param query_tensors: A batch of query tensors. Should be a long tensor.
:type query_tensors: torch.Tensor
:param response_tensors: A batch of response tensors. Should be a long tensor.
:type response_tensors: torch.Tensor
:param logprobs: A batch of log probabilities from policy
:type logprobs: torch.Tensor
:param rewards: A batch of rewards
:type rewards: torch.Tensor
"""
query_tensors: TensorType["batch_size", "query_size"]
response_tensors: TensorType["batch_size", "response_size"]
lens: TensorType["batch_size"]
s_lens: TensorType["batch_size"]
mask: TensorType["batch_size", "response_size"]
logprobs: TensorType["batch_size", "response_size"]
rewards: TensorType["batch_size", "response_size"]
rev_kl: TensorType["batch_size", "response_size"]
w: TensorType["batch_size", "response_size"]
inf_mask: TensorType["batch_size", "response_size", "vocab_size"]
t_rewards: TensorType["batch_size", "response_size"]
ent_rewards: TensorType["batch_size", "response_size"]
================================================
FILE: minillm/losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from torchtyping import TensorType
from .data_types import PPORLBatch
from .utils import whiten, get_entropy, get_x_entropy, get_log_probs
from transformers import mpu
from utils import all_gather, print_rank
class Loss():
def __init__(self, args, trainer):
self.args = args
self.trainer = trainer
def _get_cumsum_rewards(self, rewards):
full_rewards = torch.zeros_like(rewards[:, 0])
for t in reversed(range(rewards.size(1))):
full_rewards = self.args.gamma * full_rewards + rewards[:, t]
return full_rewards
def _get_advantages_and_returns(
self,
rewards: TensorType["batch_size", "response_size"],
response_length: int,
mask: TensorType["batch_size", "response_size"],
use_whitening: Optional[bool] = True,
    ) -> torch.Tensor:
last_rw = 0
rw_reversed = []
rewards = rewards.float()
mask = mask.float()
        lens = torch.cumsum(mask, dim=-1)
        # vectorized count of remaining response tokens at each position
        lens = mask - lens + lens[:, -1:None]
        lens = torch.masked_fill(lens, lens==0, 1)  # avoid division by zero
for t in reversed(range(response_length)):
rw_delta = rewards[:, t]
last_rw = rw_delta + self.args.gamma * last_rw
rw_reversed.append(last_rw)
rw = torch.stack(rw_reversed[::-1], dim=1)
rw = rw / lens
advantages = rw
if use_whitening:
advantages = whiten(advantages)
return advantages.detach()
def _pg_loss(
self,
logprobs: TensorType["batch_size", "response_size"],
old_logprobs: TensorType["batch_size", "response_size"],
advantages: TensorType["batch_size", "response_size"],
mask: TensorType["batch_size", "response_size"],
w: TensorType["batch_size", "response_size"],
):
"""PPO objective function.
References:
- https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
"""
n = mask.sum()
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio.float())
ratio = ratio * w
        if torch.isinf(advantages).any():
            print("[ERROR] advantage inf")
        if torch.isinf(ratio).any():
            print("[ERROR] ratio inf")
        if torch.isnan(advantages).any():
            print("[ERROR] advantage nan")
        if torch.isnan(ratio).any():
            print("[ERROR] ratio nan")
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(
ratio,
1.0 - self.args.cliprange,
1.0 + self.args.cliprange,
)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2).float() * mask) / n
return pg_loss
def _reg_loss(self, query_ids, response_ids, mask, logits, inf_mask, stats):
with torch.no_grad():
t_logits = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask, base="teacher", return_logprobs=False)
loss_exp_ent = 0
xent = get_x_entropy(logits, t_logits, inf_mask, mask, model_parallel=self.args.model_parallel)
s_ent = get_entropy(logits, inf_mask, mask, model_parallel=self.args.model_parallel)
loss_exp_ent = torch.sum((xent - s_ent) * mask) / mask.sum()
stats["reg_loss"] = loss_exp_ent.item()
return loss_exp_ent
def get_input_batch(self, ppo_batch: PPORLBatch, pt_batch):
query_tensors = ppo_batch.query_tensors
response_tensors = ppo_batch.response_tensors
ppo_input_batch = self.trainer.get_model_inputs(query_tensors, response_tensors)
pt_input_batch, _ = pt_batch
# merge batch
assert len(ppo_input_batch) == len(pt_input_batch), list(ppo_input_batch.keys())
input_batch = {}
for k in ppo_input_batch:
input_batch[k] = torch.cat([ppo_input_batch[k], pt_input_batch[k]], dim=0)
return input_batch
def ppo_loss(self, batch: PPORLBatch, logits):
stats = {}
query_tensors = batch.query_tensors
response_tensors = batch.response_tensors
lens = batch.lens
s_lens = batch.s_lens
mask = batch.mask
old_logprobs = batch.logprobs
old_rewards = batch.rewards
rev_kl = batch.rev_kl
w = batch.w
inf_mask = batch.inf_mask
response_length = response_tensors.shape[-1]
start = query_tensors.size(1) - 1 # "-1" for the first generated token AS TARGET
end = query_tensors.size(1) + response_tensors.size(1) - 1 # "remove the last token that does not have target"
logits = logits / self.args.temperature
logits = logits[:, start:end]
if inf_mask is not None:
logits = logits.masked_fill(inf_mask, -float("inf"))
tokens = torch.cat((query_tensors, response_tensors), dim=1)[
:, -self.trainer.max_length :
]
mask = self.trainer.get_mask(tokens)[:, start:end]
logprobs = get_log_probs(logits, response_tensors, mask, inf_mask, model_parallel=self.args.model_parallel)
advantages = self._get_advantages_and_returns(
old_rewards, response_length, mask
)
loss = self._pg_loss(
logprobs=logprobs,
old_logprobs=old_logprobs,
advantages=advantages,
mask=mask,
w=w,
)
stats["pg_loss"] = loss.item()
        single_step_reg_loss = self._reg_loss(query_tensors, response_tensors, mask, logits, inf_mask, stats)  # records stats["reg_loss"]
if self.args.single_step_reg:
loss += single_step_reg_loss
stats["rl_loss"] = loss.item()
with torch.no_grad():
# generation values for reward
cumsum_rewards = self._get_cumsum_rewards(old_rewards)
rev_kl = torch.sum(rev_kl, dim=-1)
if self.args.length_norm:
cumsum_rewards = cumsum_rewards / lens
rev_kl = rev_kl / s_lens
cumsum_rewards = all_gather(cumsum_rewards, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).mean(dim=0).item()
rev_kl = all_gather(rev_kl, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).mean(dim=0).item()
lens = all_gather(lens, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).float().mean(dim=0).item()
s_lens = all_gather(s_lens, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).float().mean(dim=0).item()
stats["reward"] = cumsum_rewards
stats["rev_kl"] = rev_kl
stats["mixed_lens"] = lens
stats["stu_lens"] = s_lens
return loss, stats
def pt_loss(self, batch, logits):
stats = {}
model_batch, no_model_batch = batch
loss_mask = (no_model_batch["label"] != -100).int()
if self.args.model_parallel:
lm_losses = mpu.parallel_cross_entropy(logits.contiguous().float(), no_model_batch["label"]).view(-1)
lm_loss = (lm_losses * loss_mask.view(-1)).sum(-1) / loss_mask.view(-1).sum(-1)
else:
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
lm_loss = loss_fn(logits.view(-1, logits.size(-1)), no_model_batch["label"].view(-1))
distil_loss = 0
if self.trainer.teacher_model is not None and self.args.kd_ratio is not None:
with torch.no_grad():
teacher_outputs = self.trainer.teacher_model(**model_batch, return_dict=True, use_cache=False)
teacher_logits = teacher_outputs.logits
if self.args.model_parallel:
distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
distil_losses = distil_losses.view(-1)
distil_loss = (distil_losses * loss_mask.view(-1)).sum(-1) / loss_mask.view(-1).sum(-1)
else:
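# Word-level distillation without model parallelism: cross-entropy between the
# teacher's soft targets and the student's log-probabilities, averaged over the loss mask.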
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
inf_mask = torch.isinf(logits)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = -torch.sum(x * loss_mask.view(-1), dim=0) / torch.sum(loss_mask.view(-1), dim=0)
loss = (1-self.args.kd_ratio) * lm_loss + self.args.kd_ratio * distil_loss
stats["pt_loss"] = loss.item()
stats["lm_loss"] = lm_loss.item()
stats["ds_loss"] = distil_loss.item()
return loss, stats
================================================
FILE: minillm/model.py
================================================
import torch.nn as nn
from transformers import AutoConfig
from utils import get_model
class PPOModel(nn.Module):
def __init__(self, args, device):
super().__init__()
self.model_parallel = args.model_parallel
self.config = AutoConfig.from_pretrained(args.model_path)
self.base_model = get_model(args, device)
self.base_model.eval() # no dropout for RL
def forward(self, **x):
base_model_outputs = self.base_model(**x)
return base_model_outputs
def generate(self, **x):
return self.base_model.generate(**x)
def set_force_gradient_checkpointing(self, value):
self.base_model.set_force_gradient_checkpointing(value)
================================================
FILE: minillm/pipelines.py
================================================
import os
import json
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, DistributedSampler
from transformers import mpu
import torch.distributed as dist
from data_utils.distributed_indexed import DistributedMMapIndexedDataset
from torch.distributed import get_rank, get_world_size
from utils import print_rank
class PPOPipeline():
def __init__(self, args, tokenizer, split, ppo_data_path=None, fix_prompts=False, num=-1):
super().__init__()
self.tokenizer = tokenizer
self.args = args
self.split = split
self.pad_id = self.tokenizer.eos_token_id
self.max_length = args.max_length
self.rng_ppo = random.Random(args.seed_ppo)
self.min_prompt_length = args.min_prompt_length
self.max_prompt_length = args.max_prompt_length
self.ppo_ctx = DistributedMMapIndexedDataset(ppo_data_path, f"{split}", get_rank(), get_world_size())
self.ppo_raw, self.ppo_answers = None, None
if os.path.exists(os.path.join(ppo_data_path, f"{split}.jsonl")):
with open(os.path.join(ppo_data_path, f"{split}.jsonl")) as f:
self.ppo_raw = [json.loads(line) for line in f.readlines()]
self.ppo_answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.ppo_raw]
self.num = min(num, len(self.ppo_ctx)) if num > 0 else len(self.ppo_ctx)
self.fix_prompts = fix_prompts
self.prompt_lengths = [None for _ in range(self.num)]  # use self.num: `num` may be -1 (meaning "use all instances")
print_rank(f"Num PPO instances: {len(self.ppo_ctx)}")
def __len__(self):
return self.num
def __getitem__(self, index: int):
data = self.ppo_ctx[index].astype(int)
assert len(data) <= self.max_prompt_length
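# 65535 marks the boundary between prompt and response ids in the packed data;
# skipped for qwen, presumably because 65535 can be a valid token id there.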
if self.args.model_type!="qwen" and 65535 in data:
source_len = np.where(data==65535)[0][0]
prompt = data[:source_len]
response = data[source_len+1:]
else:
prompt = data
response = None
# return prompt, rest
return prompt, response
def collate(self, samples):
bs = len(samples)
max_prompt_length = self.max_prompt_length
model_batch = {
"input_ids": torch.ones(bs, max_prompt_length, dtype=torch.long) * self.pad_id,
"attention_mask": torch.zeros(bs, max_prompt_length, dtype=torch.long),
}
no_model_batch = {
"full_ids": torch.ones(bs, self.max_length, dtype=torch.long) * self.pad_id,
"full_attention_mask": torch.zeros(bs, self.max_length, dtype=torch.long),
"full_label_ids": torch.ones(bs, self.max_length, dtype=torch.long) * -100,
}
for i, (prompt, response) in enumerate(samples):
# left padding
model_batch["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
model_batch["attention_mask"][i][-len(prompt):] = 1
if response is not None:
full_ids = np.concatenate([prompt, response], axis=0)
no_model_batch["full_ids"][i][:len(full_ids)-1] = torch.tensor(full_ids[:-1], dtype=torch.long)
no_model_batch["full_attention_mask"][i][:len(full_ids)-1] = 1.0
no_model_batch["full_label_ids"][i][len(prompt)-1:len(full_ids)-1] = torch.tensor(response, dtype=torch.long)
return model_batch, no_model_batch
def move_to_device(self, model_batch, no_model_batch, device):
for k in model_batch:
model_batch[k] = model_batch[k].to(device)
for k in no_model_batch:
no_model_batch[k] = no_model_batch[k].to(device)
return model_batch, no_model_batch
def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
if self.args.model_parallel:
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
else:
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last, rank=dp_rank, num_replicas=dp_world_size)
return DataLoader(
self, sampler=sampler, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers
)
class LMPipeline():
def __init__(self, args, tokenizer, split, lm_data_path=None, num=-1):
super().__init__()
self.tokenizer = tokenizer
self.args = args
self.split = split
self.pad_id = self.tokenizer.eos_token_id
self.max_length = args.max_length
self.rng_lm = random.Random(args.seed_lm)
self.lm_ctx = DistributedMMapIndexedDataset(lm_data_path, f"{split}", get_rank(), get_world_size())
self.num = min(num, len(self.lm_ctx)) if num > 0 else len(self.lm_ctx)
print_rank(f"Num LM instances: {len(self.lm_ctx)}")
def __len__(self):
return self.num
def __getitem__(self, index):
return self._get_lm(index)
def _get_lm(self, index):
data = self.lm_ctx[index]
input_ids = data.astype(int)
return {
"input_ids": input_ids[:self.max_length]
}
def _process_lm(self, i, samp, model_data, no_model_data):
input_ids = samp["input_ids"]
source_len = 1
if self.args.model_type!="qwen" and 65535 in input_ids:
source_len = np.where(input_ids==65535)[0][0]
input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0)
input_ids = input_ids[:self.max_length]
input_len = len(input_ids)
model_data["input_ids"][i][:input_len-1] = torch.tensor(input_ids[:-1], dtype=torch.long)
model_data["attention_mask"][i][:input_len-1] = 1.0
if self.args.model_type in ["gpt2"]:
model_data["position_ids"][i][:input_len-1] = torch.arange(0, input_len-1, dtype=torch.long)
no_model_data["label"][i][:input_len-1] = torch.tensor(input_ids[1:], dtype=torch.long)
no_model_data["label"][i][:source_len-1] = -100
no_model_data["loss_mask"][i][:input_len-1] = 1.0
no_model_data["loss_mask"][i][:source_len-1] = 0
def move_to_device(self, model_batch, no_model_batch, device):
for k in model_batch:
model_batch[k] = model_batch[k].to(device)
for k in no_model_batch:
no_model_batch[k] = no_model_batch[k].to(device)
return model_batch, no_model_batch
def collate(self, samples):
bs = len(samples)
max_length = self.max_length
model_data = {
"input_ids": torch.ones(bs, max_length, dtype=torch.long) * self.pad_id,
"attention_mask": torch.zeros(bs, max_length, dtype=torch.long)
}
if self.args.model_type in ["gpt2"]:
model_data["position_ids"] = torch.zeros(bs, max_length, dtype=torch.long)
no_model_data = {
"label": torch.ones(bs, self.max_length, dtype=torch.long) * -100,
"loss_mask": torch.zeros(bs, max_length)
}
for i, samp in enumerate(samples):
self._process_lm(i, samp, model_data, no_model_data)
return model_data, no_model_data
def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
if self.args.model_parallel:
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
else:
dp_world_size = dist.get_world_size()
dp_rank = dist.get_rank()
sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last, rank=dp_rank, num_replicas=dp_world_size)
return DataLoader(
self, sampler=sampler, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers
)
================================================
FILE: minillm/reward.py
================================================
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
mpu)
class Reward():
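"""Teacher-based reward: scores each generated token with the teacher's log-probability."""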
def __init__(self, args, tokenizer: AutoTokenizer, model: AutoModelForCausalLM):
self.args = args
self.tokenizer = tokenizer
self.model = model
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id
def get_input_batch(self, input_ids, gen_ids, output_pos=True):
full_ids = torch.cat([input_ids, gen_ids], dim=-1)
attention_mask = (full_ids != self.pad_token_id)
model_inputs = {
"input_ids": full_ids,
"attention_mask": attention_mask,
"use_cache": False
}
if (self.args.model_type in ["gpt2"]) and output_pos:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(~attention_mask, 0)
model_inputs["position_ids"] = position_ids
return model_inputs
def reward_fn(self, input_ids, gen_ids, inf_mask=None, output_pos=True):
# not include eos token
self.model.eval()
# input_ids = input_ids.repeat(1, 1)
model_inputs = self.get_input_batch(input_ids, gen_ids, output_pos=output_pos)
with torch.no_grad():
outputs = self.model(**model_inputs)
logits = outputs.logits # (B, L, V)
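# Mean-center the logits for numerical stability; the shift cancels in
# `selection_value - next_state_value` (logit minus logsumexp = log-softmax),
# so the final scores equal the teacher's token log-probabilities.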
if self.args.model_parallel:
logits = logits - mpu.parallel_mean(logits.float(), dim=-1).unsqueeze(-1)
else:
logits = logits - torch.mean(logits, dim=-1, keepdim=True)
mask = model_inputs["attention_mask"]
logits = logits * mask.unsqueeze(-1) # set logits output by padding to 0
logits = logits[:, input_ids.size(-1)-1:, :]
mask = mask[:, input_ids.size(-1)-1:]
if self.args.model_parallel:
selection_value = mpu.parallel_gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)
else:
selection_value = torch.gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)
current_logits = logits[:, :-1, :]
if self.args.model_parallel:
next_state_value = mpu.parallel_logsumexp(current_logits.float(), dim=-1)
else:
next_state_value = torch.logsumexp(current_logits, dim=-1)
next_state_value = next_state_value * mask[:, :-1]
scores = selection_value - next_state_value
assert torch.all(~torch.isinf(scores) & ~torch.isnan(scores))
assert scores.size() == gen_ids.size()
return {
"rewards": scores,
"inf_mask": inf_mask
}
================================================
FILE: minillm/sampler.py
================================================
import torch
import os
from .data_types import PromptBatch, PPORLElement
from .pipelines import PPOPipeline
from .trainer import PPOTrainer
from utils import get_rank, print_rank, all_gather, save_rank
from .utils import get_rev_kl
from transformers import mpu
class PPOSampler():
"""
Orchestrator that prepares data for PPO training: transforms prompts from
`pipeline` into `PPORLElement`s and pushes them into the trainer's `store`.
"""
def __init__(
self,
args,
trainer: PPOTrainer,
pipeline: PPOPipeline,
chunk_size: int = 512,
):
self.args = args
self.pipeline = pipeline
self.trainer = trainer
self.chunk_size = chunk_size
self.pipeline_loader = self.pipeline.create_loader(
self.chunk_size, shuffle=True, drop_last=True, num_workers=self.args.num_workers
)
self.pipeline_iterator = iter(self.pipeline_loader)
self.trainer.set_sampler(self)
self.epochs = 0
def run_sample(self, num_rollouts_per_device: int = 1024, iter_count: int = 0):
"""
Takes `num_rollouts_per_device` prompts from `pipeline`, samples the model, and computes the
KL against the reference (teacher) model. It then appends `PPORLElement`s to the trainer's `store`.
"""
ppo_rl_elements = []
while len(ppo_rl_elements) < num_rollouts_per_device:
if (not self.args.model_parallel) or mpu.get_model_parallel_rank() == 0:
print(f"Rank {get_rank()}: Number Sampling Elements {len(ppo_rl_elements)} / {num_rollouts_per_device}")
try:
batch: PromptBatch = next(self.pipeline_iterator)
except StopIteration:
self.epochs += 1
print_rank(f"Another outer ppo epoch, outer ppo epoch: {self.epochs}")
save_rank(f"Another outer ppo epoch, outer ppo epoch: {self.epochs}", os.path.join(self.args.save, "log.txt"))
self.pipeline_loader.sampler.set_epoch(self.epochs)
self.pipeline_iterator = iter(self.pipeline_loader)
batch = next(self.pipeline_iterator)
batch, no_model_batch = batch
n = batch["input_ids"].size(0)
batch, no_model_batch = self.pipeline.move_to_device(batch, no_model_batch, self.trainer.device)
query_ids = batch["input_ids"]
# generate and compute rollout scores
with torch.no_grad():
mode = "base"
gen_out = self.trainer.generate(**batch, return_dict_in_generate=True, mode=mode, teacher_mixed_sample=(self.args.teacher_mixed_alpha is not None), output_scores=True)
full_ids = gen_out.sequences
response_ids = full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
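# Mask aligned with target positions: logits at step t-1 predict token t, hence the shift by one.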
mask = (full_ids != self.trainer.tokenizer.pad_token_id)[:, query_ids.size(-1)-1:query_ids.size(-1)+response_ids.size(-1)-1]
lens = torch.sum(mask, dim=-1)
gen_logits = gen_out.scores # NOTE: [b, s, h_p]
inf_mask = torch.isinf(gen_logits)
scores = self.trainer.reward_fn(query_ids, response_ids, inf_mask=inf_mask)
t_rewards = scores["rewards"]
inf_mask = scores["inf_mask"]
_, rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base=mode)
# student generation features
if self.args.teacher_mixed_alpha is not None:
s_gen_out = self.trainer.generate(**batch, return_dict_in_generate=True, mode=mode, output_scores=True)
s_full_ids = s_gen_out.sequences
s_response_ids = s_full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
s_inf_mask = torch.isinf(s_gen_out.scores)
s_scores = self.trainer.reward_fn(query_ids, s_response_ids, inf_mask=s_inf_mask)
s_t_rewards = s_scores["rewards"]
s_inf_mask = s_scores["inf_mask"]
_, s_rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, s_response_ids, inf_mask=s_inf_mask, base=mode)
s_mask = (s_full_ids != self.trainer.tokenizer.pad_token_id)[:, query_ids.size(-1)-1:query_ids.size(-1)+s_response_ids.size(-1)-1]
s_lens = torch.sum(s_mask, dim=-1)
else:
s_t_rewards = t_rewards
s_rollout_logprobs = rollout_logprobs
s_mask = mask
s_lens = lens
rev_kl = get_rev_kl(s_t_rewards, s_rollout_logprobs, s_mask)
if self.args.teacher_mixed_alpha is not None:
with torch.no_grad():
_, t_rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base="teacher") # recompute because of the fp16 loss
# get logprobs and the importance sampling weight w
with torch.no_grad():
if self.args.teacher_mixed_alpha is not None:
_, raw_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base="base") # raw_logprobs: compute using the new model
logprobs = raw_logprobs
mix_probs = (1 - self.args.teacher_mixed_alpha) * torch.exp(rollout_logprobs.float()) + self.args.teacher_mixed_alpha * torch.exp(t_rollout_logprobs.float())
mix_logprobs = torch.log(mix_probs)
log_w = logprobs - mix_logprobs
w = torch.exp(log_w) # importance sampling weight
else:
raw_logprobs = rollout_logprobs
logprobs = rollout_logprobs
w = torch.ones_like(logprobs)
# get ent_rewards
ent_rewards = -logprobs
rewards = t_rewards + ent_rewards
if self.args.reward_scaling is not None:
rewards = rewards / self.args.reward_scaling
clip_reward = self.args.cliprange_reward
if clip_reward:
rewards = torch.clip(rewards, -clip_reward, clip_reward)
query_ids = query_ids.cpu()
response_ids = response_ids.cpu()
lens = lens.cpu()
s_lens = s_lens.cpu()
mask = mask.cpu()
logprobs = logprobs.cpu()
rewards = rewards.cpu()
rev_kl = rev_kl.cpu()
w = w.cpu()
inf_mask = inf_mask.cpu()
new_ppo_rl_elements = [
PPORLElement(
query_tensor=query_ids[i],
response_tensor=response_ids[i],
lens=lens[i],
s_lens=s_lens[i],
mask=mask[i],
logprobs=logprobs[i],
rewards=rewards[i],
rev_kl=rev_kl[i],
w=w[i],
inf_mask=inf_mask[i],
t_rewards=t_rewards[i],
ent_rewards=ent_rewards[i]
)
for i in range(n)
]
ppo_rl_elements.extend(new_ppo_rl_elements)
ppo_rl_elements = ppo_rl_elements[:num_rollouts_per_device]
# Push samples and rewards to trainer's rollout storage
self.trainer.push_to_store(ppo_rl_elements)
if self.args.save_rollout:
all_query_ids = all_gather(torch.stack([e.query_tensor for e in ppo_rl_elements], dim=0).to(self.trainer.device))
all_response_ids = all_gather(torch.stack([e.response_tensor for e in ppo_rl_elements], dim=0).to(self.trainer.device))
all_entropy = all_gather(torch.stack([e.ent_rewards for e in ppo_rl_elements], dim=0).to(self.trainer.device))  # PPORLElement stores per-token entropy rewards as `ent_rewards`
rollout_save_path = os.path.join(self.args.save, "rollout_history", str(iter_count))
if get_rank() == 0:
os.makedirs(rollout_save_path, exist_ok=True)
torch.save((all_query_ids, all_response_ids, all_entropy), os.path.join(rollout_save_path, "all.pt"))
================================================
FILE: minillm/storages.py
================================================
import json
import os
import time
from abc import abstractmethod
from typing import Any, Callable, Iterable
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
from .data_types import PPORLElement, PPORLBatch
from utils import get_rank
class BaseRolloutStore(Dataset):
def __init__(self, capacity=-1):
self.history: Iterable[Any] = None
self.capacity = capacity
@abstractmethod
def push(self, exps: Iterable[Any]):
"""
Push experiences to rollout storage
"""
pass
def __getitem__(self, index: int) -> PPORLElement:
return self.history[index]
def __len__(self) -> int:
return len(self.history)
@abstractmethod
def create_loader(
self,
batch_size: int,
shuffle: bool,
prep_fn: Callable = None,
num_workers: int = 0,
drop_last: bool = False
) -> DataLoader:
"""
Create a dataloader for the rollout store
:param prep_fn: Applied to RLElement after collation (typically tokenizer)
:type prep_fn: Callable
"""
pass
@abstractmethod
def broadcast(self, batch, src=0, group=None):
pass
@abstractmethod
def move_to_device(self, batch, device):
pass
class PPORolloutStorage(BaseRolloutStore):
"""
Rollout storage for training PPO
"""
def __init__(self, pad_token_id, seed):
super().__init__()
self.pad_token_id = pad_token_id
self.history: Iterable[PPORLElement] = [None]
self.rng = torch.Generator()
self.rng.manual_seed(seed)
def push(self, exps: Iterable[PPORLElement]):
self.history += exps
def save(self, path):
def exp_to_dict(exp):
return {k: v for k, v in exp.__dict__.items()}
data = [exp_to_dict(exp) for exp in self.history]
torch.save(data, os.path.join(path, f"{get_rank()}.pkl"))
def load(self, path):
data = torch.load(os.path.join(path, f"history_{get_rank()}.pkl"), map_location="cpu")
self.history = [PPORLElement(**d) for d in data]
def clear_history(self):
self.history = []
def export_history(self, location: str):
assert os.path.exists(location)
fpath = os.path.join(location, f"epoch-{str(time.time())}.json")
def exp_to_dict(exp):
return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}
data = [exp_to_dict(exp) for exp in self.history]
with open(fpath, "w") as f:
f.write(json.dumps(data, indent=2))
def __getitem__(self, index: int) -> PPORLElement:
return self.history[index]
def __len__(self) -> int:
return len(self.history)
def collate(self, elems: Iterable[PPORLElement]):
if any(e is None for e in elems):
print(elems)
return PPORLBatch(
# Left padding of already left-padded queries
pad_sequence(
[elem.query_tensor.flip(0) for elem in elems],
padding_value=self.pad_token_id,
batch_first=True,
).flip(1),
# Right pad the rest, to have a single horizontal query/response split
pad_sequence(
[elem.response_tensor for elem in elems],
padding_value=self.pad_token_id,
batch_first=True,
),
torch.tensor([elem.lens for elem in elems], dtype=torch.long),
torch.tensor([elem.s_lens for elem in elems], dtype=torch.long),
pad_sequence(
[elem.mask for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.logprobs for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.rewards for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.rev_kl for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.w for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.inf_mask for elem in elems],
padding_value=0,
batch_first=True,
),
pad_sequence(
[elem.t_rewards for elem in elems],
padding_value=0.0,
batch_first=True,
),
pad_sequence(
[elem.ent_rewards for elem in elems],
padding_value=0.0,
batch_first=True,
),
)
def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
# sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last)
# we don't use distributed sampler because the dataset on each device is different
return DataLoader(
self, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, generator=self.rng
)
def broadcast(self, batch: PPORLBatch, src=0, group=None):
for k, v in batch.__dict__.items():
dist.broadcast(batch.__dict__[k], src=src, group=group)
def move_to_device(self, batch: PPORLBatch, device):
for k, v in batch.__dict__.items():
batch.__dict__[k] = batch.__dict__[k].to(device)
================================================
FILE: minillm/trainer.py
================================================
import json
import os
import deepspeed
from time import time
from typing import Optional, Tuple
from collections import defaultdict
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import AdamW
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from transformers import (
AutoTokenizer,
GenerationConfig,
mpu)
from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
from .utils import (
get_scheduler_class,
get_log_probs,
get_rev_kl,
significant
)
from .model import (
PPOModel
)
from .pipelines import PPOPipeline, LMPipeline
from .storages import PPORolloutStorage
from .losses import Loss
from utils import print_rank, save_rank, get_rank, all_gather, save_parallel
from rouge_metric import compute_metrics
class PPOTrainer():
"""
RL model trainer with a DeepSpeed-based backend
"""
def __init__(self, args, tokenizer: AutoTokenizer, reward_fn, ds_config):
self.args = args
self.max_length = args.max_length
self.ds_config = ds_config
self.reward_fn = reward_fn
self.device = torch.cuda.current_device()
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
dist.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
if args.model_parallel:
raise NotImplementedError
else:
self.dp_world_size = dist.get_world_size()
self.dp_rank = dist.get_rank()
self.dp_group = None
self.model = PPOModel(args, self.device)
if args.model_parallel:
raise NotImplementedError
else:
if dist.get_rank() == 0:
print(' > number of parameters: {}M'.format(
int(sum([p.nelement() for p in self.model.parameters()]) / 1e6)), flush=True)
self.sampler = None
self.teacher_model = None
self.opt = self.setup_optimizer()
self.scheduler = self.setup_scheduler()
self.model, self.opt, self.scheduler = self.setup_ds(self.model, self.opt, self.scheduler)
self.tokenizer = tokenizer
self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.args.seed_ppo + self.dp_rank)
self.store.clear_history()
self.losses = Loss(args, self)
self.generate_kwargs = dict(
do_sample=args.do_sample,
top_p=args.top_p,
top_k=args.top_k,
temperature=args.temperature,
max_length=args.max_length,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
)
def set_teacher_model(self, model):
self.teacher_model = model
def set_sampler(self, sampler):
self.sampler = sampler
def setup_optimizer(self):
"""
Returns an AdamW optimizer configured from `self.args`
"""
optimizer = AdamW(
self.model.parameters(),
lr=self.args.lr,
betas=[0.9, 0.95],
eps=1.0e-8,
weight_decay=1.0e-6
)
return optimizer
def setup_scheduler(self):
"""
Returns a learning rate scheduler configured from `self.args`
"""
if self.args.scheduler_name == "constant_trm":
scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=self.args.warmup_iters)
elif self.args.scheduler_name == "cosine_trm":
scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=self.args.warmup_iters, num_training_steps=self.args.total_iters)
else:
scheduler_class = get_scheduler_class(self.args.scheduler_name)
scheduler = scheduler_class(self.opt, eta_min=self.args.lr_min, T_max=self.args.total_iters)
return scheduler
def setup_ds(self, model, optimizer=None, scheduler=None):
model, optimizer, _, scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=self.args,
lr_scheduler=scheduler,
mpu=mpu if self.args.model_parallel else None,
config_params=self.ds_config
)
return model, optimizer, scheduler
def add_eval_pipeline(self, eval_pipeline: PPOPipeline):
"""Adds pipeline from with validation prompts"""
self.eval_pipeline = eval_pipeline
def add_lm_pipeline(self, lm_pipeline: LMPipeline, eval_lm_pipeline: LMPipeline):
self.lm_pipeline = lm_pipeline
self.eval_lm_pipeline = eval_lm_pipeline
def get_model_inputs(
self,
query_tensors,
response_tensors,
) -> dict:
tokens = torch.cat((query_tensors, response_tensors), dim=1)[
:, -self.max_length :
]
attention_mask = self.get_mask(tokens)
batch = {
"input_ids": tokens,
"attention_mask": attention_mask
}
if self.args.model_type in ["gpt2"]:
# For a proper positional encoding in case of left padding
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 0)
batch["position_ids"] = position_ids
return batch
def get_mask(self, tokens):
attention_mask = (
tokens.not_equal(self.tokenizer.pad_token_id).long()
)
return attention_mask
def forward_model(self, batch):
outputs = self.model(
**batch,
return_dict=True,
use_cache=False,
)
return outputs
def compute_logits_and_log_probs(self, query_ids, response_ids, inf_mask=None, base="base", return_logprobs=True):
batch = self.get_model_inputs(
query_ids, response_ids
)
if base == "base":
model_cls = self.model.module.forward
elif base == "teacher":
model_cls = self.teacher_model
else:
raise NotImplementedError
outputs = model_cls(
**batch,
return_dict=True,
use_cache=False
)
logits = outputs.logits
logits = logits / self.args.temperature
start = query_ids.size(1) - 1
end = query_ids.size(1) + response_ids.size(1) - 1
logits = logits[:, start:end]
if inf_mask is not None:
logits = logits.masked_fill(inf_mask, -float("inf"))
mask = batch["attention_mask"][:, start:end]
if return_logprobs:
logprobs = get_log_probs(logits, response_ids, mask, inf_mask, model_parallel=self.args.model_parallel)
return logits, logprobs
return logits
def train(self):
"""
Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
"""
self.prepare_learning()
self.iter_count = 1
self.global_iter_count = 1
self.nth_evaluation = 0
self.evaluate()
print_rank("Total Steps:", self.total_steps, "Data Epochs:", self.args.epochs)
lm_epochs = 0
logging_stats = defaultdict(float)
for training_epoch in range(self.args.training_epochs):
for ppo_epoch in range(self.n_updates_per_batch):
for it, batch in enumerate(self.train_dataloader):
if self.lm_pipeline is not None:
try:
lm_batch = next(self.lm_iterator)
except StopIteration:
lm_epochs += 1
print_rank(f"Another lm epoch, lm epochs: {lm_epochs}")
save_rank(f"Another lm epoch, lm epochs: {lm_epochs}", os.path.join(self.args.save, "log.txt"))
self.lm_dataloader.sampler.set_epoch(lm_epochs)
self.lm_iterator = iter(self.lm_dataloader)
lm_batch = next(self.lm_iterator)
self.store.move_to_device(batch, self.device)
self.lm_pipeline.move_to_device(*lm_batch, self.device)
stats = {}
if self.args.model_parallel:
raise NotImplementedError
if self.args.gradient_checkpointing:
try: self.model.module.set_force_gradient_checkpointing(True)
except: self.model.module.base_model.set_force_gradient_checkpointing(True)
input_batch = self.losses.get_input_batch(batch, lm_batch)
logits = self.forward_model(input_batch).logits
ppo_logits = logits[:batch.query_tensors.size(0)]
lm_logits = logits[batch.query_tensors.size(0):]
# forward
forward_time = time()
# compute rl-related loss on explored data
rl_loss, rl_loss_stats = self.losses.ppo_loss(batch, ppo_logits)
stats.update(rl_loss_stats)
# compute lm-related loss on pre-training data
pt_loss, pt_loss_stats = self.losses.pt_loss(lm_batch, lm_logits)
stats.update(pt_loss_stats)
loss = rl_loss + self.args.lm_coef * pt_loss
stats["tot_loss"] = loss.item()
forward_time = time() - forward_time
# backward
backward_time = time()
self.model.backward(loss)
backward_time = time() - backward_time
# step
step_time = time()
self.model.step()
step_time = time() - step_time
if self.args.gradient_checkpointing:
try: self.model.module.set_force_gradient_checkpointing(False)
except: self.model.module.base_model.set_force_gradient_checkpointing(False)
if self.iter_count % self.args.gradient_accumulation_steps == 0 and \
((self.global_iter_count < 10000 and (self.global_iter_count % 1000 == 0)) or \
self.global_iter_count % self.args.save_interval == 0):
self.save()
# eval
if self.iter_count % self.args.gradient_accumulation_steps == 0 and \
((self.global_iter_count < 1000 and (self.global_iter_count % 100 == 0)) or \
(self.global_iter_count % self.args.eval_interval == 0)):
self.evaluate()
elapsed_time = forward_time + backward_time + step_time
stats["elapsed_time"] = elapsed_time
for k in stats:
logging_stats[k] += stats[k]
# Logging
def get_log(log_stats, one_step_time):
keys = ["tot_loss", "rl_loss", "pt_loss", "pg_loss", "reg_loss", "reward", "rev_kl", "stu_lens", "mixed_lens"]
prefix = "train | data_epochs {:2d}/{:2d} | inner iter: {:3d}/{:3d} | ppo epoch: {:2d}/{:2d} | global iter: {:6d}/{:6d}".format(
self.sampler.epochs,
self.args.epochs,
it,
len(self.train_dataloader),
ppo_epoch,
self.n_updates_per_batch,
self.global_iter_count,
self.total_steps
)
suffix = "| lr: {:.4e} | scale: {:6.2f} | time: {:.3f} | step time: {:.3f}".format(
self.scheduler.get_last_lr()[0],
self.opt.cur_scale if hasattr(self.opt, "cur_scale") else 0,
elapsed_time,
one_step_time
)
for key in keys:
prefix += "| {}: {:.4f} ".format(key, log_stats.get(key, 0))
return prefix + suffix
mid_log_step = self.args.gradient_accumulation_steps // self.args.mid_log_num
mid_log_step = 1 if mid_log_step == 0 else mid_log_step
if self.iter_count % mid_log_step == 0:
print_rank(get_log(stats, 0))
if self.global_iter_count % self.args.log_interval == 0 and self.iter_count % self.args.gradient_accumulation_steps == 0:
logging_stats = {k:v/(self.args.log_interval*self.args.gradient_accumulation_steps) for k,v in logging_stats.items()}
log_str = get_log(logging_stats, logging_stats.get("elapsed_time", 0) * self.args.gradient_accumulation_steps)
print_rank("*" * 100)
print_rank(log_str)
print_rank(self.args.save)
print_rank("*" * 100)
save_rank(log_str, os.path.join(self.args.save, "log.txt"))
logging_stats = {k:0 for k in logging_stats}
# end
if (self.global_iter_count >= self.total_steps or self.sampler.epochs >= self.args.epochs):
if self.global_iter_count >= self.total_steps:
print_rank("Reached total steps {}/{}".format(self.global_iter_count, self.total_steps))
else:
print_rank("Reached data epochs {}/{}".format(self.sampler.epochs, self.args.epochs))
self.save()
results, preds, response_texts = self.evaluate_ppo()
if self.eval_lm_pipeline is not None:
eval_pt_results = self.evaluate_pt()
results.update(eval_pt_results)
self.save_evals(preds, results, response_texts)
return results
self.iter_count += 1
if self.iter_count % self.args.gradient_accumulation_steps == 0:
self.global_iter_count += 1
self.post_backward_callback()
self.post_epoch_callback(training_epoch)
def post_backward_callback(self):
pass
def post_epoch_callback(self, epoch):
self.store.clear_history()
# self.store.load(self.args.save)
self.sampler.run_sample(
self.args.num_rollouts_per_device, self.global_iter_count
) # Collect more rollouts for training
def prepare_learning(self):
self.train_dataloader = self.store.create_loader(
self.args.batch_size, shuffle=True, num_workers=self.args.num_workers, drop_last=True
)
self.eval_dataloader = self.eval_pipeline.create_loader(
self.args.batch_size, shuffle=False, num_workers=self.args.num_workers, drop_last=False)
self.lm_dataloader = self.lm_pipeline.create_loader(
self.args.batch_size, shuffle=True, num_workers=self.args.num_workers, drop_last=True)
self.lm_iterator = iter(self.lm_dataloader)
self.eval_lm_dataloader = self.eval_lm_pipeline.create_loader(
self.args.batch_size, shuffle=False, num_workers=self.args.num_workers, drop_last=False)
self.n_updates_per_batch = self.args.ppo_epochs
self.total_steps = int(
self.args.training_epochs
* self.n_updates_per_batch
* len(self.train_dataloader)
/ self.args.gradient_accumulation_steps
)
self.total_steps = min(self.total_steps, self.args.total_iters)
def evaluate(self):
eval_results = {}
eval_rl_results, preds, response_texts = self.evaluate_ppo()
eval_results.update(eval_rl_results)
eval_pt_results = self.evaluate_pt()
eval_results.update(eval_pt_results)
response_texts = response_texts[:len(self.eval_pipeline.ppo_answers)]
self.save_evals(preds, eval_results, response_texts)
if get_rank() == 0:
res = compute_metrics(response_texts, self.eval_pipeline.ppo_answers)
eval_results.update(res)
keys = ["rougeL", "exact_match", "rev_kl", "lens", "pt_loss", "lm_loss", "kd_loss"]
eval_log_str = "eval "
for key in keys:
eval_log_str += "| {}: {:.3f} ".format(key, eval_results[key])
print_rank(eval_log_str)
save_rank(eval_log_str, os.path.join(self.args.save, "log.txt"))
def evaluate_ppo(self): # noqa: C901
"""Samples the model on evaluation prompts and logs stats computed with `reward_fn`"""
# self.model.eval()
stats = {}
all_full_ids = []
all_rev_kl = []
all_lens = []
table = []
with torch.no_grad():
for batch in tqdm(self.eval_dataloader, "Generation Evaluation", disable=(not get_rank() == 0)):
batch, no_model_batch = batch
batch, _ = self.eval_pipeline.move_to_device(batch, no_model_batch, self.device)
gen_out = self.generate(
**batch,
return_dict_in_generate=True,
output_scores=True
)
full_ids = gen_out.sequences
gen_logits = gen_out.scores # NOTE: [b, s, h_p]
inf_mask = torch.isinf(gen_logits)
all_full_ids.append(full_ids)
input_ids = batch["input_ids"]
gen_ids = full_ids[:, input_ids.size(1):]
mask = self.get_mask(full_ids)
mask = mask[:, input_ids.size(1)-1:input_ids.size(1)+gen_ids.size(1)-1]
lens = torch.sum(mask, dim=-1)
teacher_rewards = self.reward_fn(input_ids, gen_ids)["rewards"] # \log p(y_t | y_{<t}, x)
_, logprobs = self.compute_logits_and_log_probs(input_ids, gen_ids, inf_mask=inf_mask, base="base") # \log q_{\theta}(y_t | y_{<t}, x)
kl = get_rev_kl(teacher_rewards, logprobs, mask)
kl = kl.sum(-1)
if self.args.length_norm:
kl = kl / lens
all_rev_kl.append(kl)
all_lens.append(lens)
all_full_ids = torch.cat(all_full_ids, dim=0)
all_rev_kl = torch.cat(all_rev_kl, dim=0)
all_lens = torch.cat(all_lens, dim=0)
full_ids = all_gather(all_full_ids, dim=1, world_size=self.dp_world_size, group=self.dp_group, op="stack")
full_ids = full_ids.view(-1, full_ids.size(-1))
prompt_ids = full_ids[:, :self.eval_pipeline.max_prompt_length]
all_rev_kl = all_gather(all_rev_kl, dim=0, world_size=self.dp_world_size, group=self.dp_group)
stats["rev_kl"] = all_rev_kl.mean()
all_lens = all_gather(all_lens, dim=0, world_size=self.dp_world_size, group=self.dp_group)
stats["lens"] = all_lens.float().mean()
response_texts = []
if get_rank() == 0:
prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
response_texts = self.tokenizer.batch_decode(full_ids[:, self.eval_pipeline.max_prompt_length:], skip_special_tokens=True)
gen_texts = [p + g for p, g in zip(prompt_texts, response_texts)]
columns = ["prompts"]
columns_data = [prompt_texts]
# in online setting, compute the reward for validation
columns.append("samples")
if isinstance(gen_texts[0], str):
columns_data.append(gen_texts)
else:
columns_data.append(gen_texts.tolist())
table.append(list(zip(*columns_data)))
# Log and display evaluation metrics
if get_rank() == 0:
rows = sum(list(map(list, zip(*table))), [])
# Add metrics/rewards to the table's title
table_title = f"Evaluation #{self.nth_evaluation}"
for k, x in stats.items():
if k.startswith("reward") or k.startswith("metrics"):
table_title += f" {k}: {significant(x)}"
rich_table = Table(*columns, title=table_title, show_lines=True)
for ix in range(min(3, len(rows))):
rich_table.add_row(*[str(significant(x)) for x in rows[ix]])
try:
Console().print(rich_table)
except:
pass
self.nth_evaluation += 1
return stats, table, response_texts
def evaluate_pt(self):
all_pt_losses = []
all_lm_losses = []
all_kd_losses = []
for batch in tqdm(self.eval_lm_dataloader, desc="LM Evaluation", disable=(not get_rank() == 0)):
self.eval_lm_pipeline.move_to_device(*batch, self.device)
model_batch, _ = batch
outputs = self.model(**model_batch, return_dict=True, use_cache=False)
logits = outputs.logits
with torch.no_grad():
_, stats = self.losses.pt_loss(batch, logits)
all_pt_losses.append(stats["pt_loss"])
all_lm_losses.append(stats["lm_loss"])
all_kd_losses.append(stats["ds_loss"])
all_pt_losses = torch.tensor(all_pt_losses, device=self.device)
eval_pt_loss = all_gather(all_pt_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
all_lm_losses = torch.tensor(all_lm_losses, device=self.device)
eval_lm_loss = all_gather(all_lm_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
all_kd_losses = torch.tensor(all_kd_losses, device=self.device)
eval_kd_loss = all_gather(all_kd_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
results = {"pt_loss": eval_pt_loss, "lm_loss": eval_lm_loss, "kd_loss": eval_kd_loss}
return results
def save(self, directory: Optional[str] = None):
"""Creates a checkpoint of the optimizer, scheduler and model"""
base_ckpt_path = directory or self.args.save
ckpt_dir = os.path.join(base_ckpt_path, f"{self.global_iter_count}")
os.makedirs(ckpt_dir, exist_ok=True)
if self.args.model_parallel:
raise NotImplementedError
else:
if get_rank() == 0:
self.model.module.base_model.save_pretrained(ckpt_dir, safe_serialization=False)
# torch.save(self.model.module.value_model.state_dict(), os.path.join(ckpt_dir, "value_model.ckpt"))
print(f"Model save to {ckpt_dir}")
self.tokenizer.save_pretrained(ckpt_dir)
def save_evals(self, preds, results, response_texts, directory: Optional[str] = None):
"""Saves evaluation predictions, metric results, and generated responses"""
base_ckpt_path = directory or self.args.save
save_dir = os.path.join(base_ckpt_path, "eval", f"{self.global_iter_count}")
os.makedirs(save_dir, exist_ok=True)
if get_rank() == 0:
torch.save(preds, os.path.join(save_dir, "preds.pt"))
torch.save(results, os.path.join(save_dir, "results.pt"))
with open(os.path.join(save_dir, "answers.jsonl"), "w") as f:
for resp in response_texts:
f.write(json.dumps({"text": resp}) + "\n")
def push_to_store(self, data):
self.store.push(data)
def generate(self, input_ids, attention_mask=None, mode="base", teacher_mixed_sample=False, **kwargs):
"""Wraps hf's `generate` adding some specific method's defaults"""
input_ids = input_ids.to(self.device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.device)
kwargs = dict(self.generate_kwargs, **kwargs)
if mode == "base":
model = self.model.module
elif mode == "teacher":
model = self.teacher_model
else:
raise NotImplementedError
mix_in_model, mix_in_alpha = None, None
if teacher_mixed_sample:
mix_in_model = self.teacher_model
mix_in_alpha = self.args.teacher_mixed_alpha
with torch.no_grad():
generation_config = GenerationConfig(**kwargs)
max_new_tokens = generation_config.max_length - input_ids.size(1)
gen = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
generation_config=generation_config,
max_new_tokens=max_new_tokens,
mix_in_model=mix_in_model,
mix_in_alpha=mix_in_alpha
)
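# Pad sequences (and scores below) to the configured max_length so downstream
# batching and stacking always see fixed-size tensors.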
gen.sequences = F.pad(
gen.sequences,
(0, self.max_length - gen.sequences.shape[1]),
value=self.tokenizer.pad_token_id,
)
if gen.scores is not None:
gen.scores = torch.stack(gen.scores, dim=1)
gen.scores = torch.cat([
gen.scores,
torch.zeros(
gen.scores.size(0),
self.max_length - self.args.max_prompt_length - gen.scores.size(1),
gen.scores.size(2),
device=gen.scores.device)],
dim=1)
# NOTE: scores: [b, s, h_p]
return gen
================================================
FILE: minillm/utils.py
================================================
import math
from enum import Enum
from numbers import Number
from typing import Tuple
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
from accelerate import init_empty_weights
from transformers import (
AutoModelForCausalLM,
AutoConfig,
)
def get_entropy(gen_logits, inf_mask, mask, model_parallel=False):
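# Shannon entropy of the per-step generation distribution, zeroed at masked positions.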
inf_mask = torch.isinf(gen_logits) | inf_mask
if model_parallel:
raise NotImplementedError
else:
full_probs = F.softmax(gen_logits, dim=-1, dtype=torch.float32)
full_logprobs = F.log_softmax(gen_logits, dim=-1, dtype=torch.float32)
full_logprobs = full_logprobs.masked_fill(inf_mask, 0)
ent = -torch.sum(full_probs * full_logprobs, dim=-1)
ent = ent * mask
return ent
def get_log_probs(logits, ids, mask, inf_mask=None, model_parallel=False):
if model_parallel:
raise NotImplementedError
else:
logprobs = F.log_softmax(logits, dim=-1)
if inf_mask is not None:
logprobs = logprobs.masked_fill(inf_mask, -float("inf"))
logprobs = torch.gather(logprobs, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)
logprobs = logprobs.masked_fill(~(mask.bool()), 0)
# we ensure that the selected logprobs are not inf or nan
assert torch.all(~torch.isinf(logprobs) & ~torch.isnan(logprobs))
return logprobs
def get_x_entropy(logits_1, logits_2, inf_mask, mask, model_parallel=False):
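# Token-level cross-entropy H(p1, p2) = -sum_v p1(v) * log p2(v), zeroed at masked positions.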
inf_mask = torch.isinf(logits_1) | torch.isinf(logits_2) | inf_mask
if model_parallel:
raise NotImplementedError
else:
full_probs = F.softmax(logits_1, dim=-1, dtype=torch.float32)
full_logprobs = F.log_softmax(logits_2, dim=-1, dtype=torch.float32)
full_logprobs = full_logprobs.masked_fill(inf_mask, 0)
xent = -torch.sum(full_probs * full_logprobs, dim=-1)
xent = xent * mask
return xent
def get_rev_kl(log_p, log_q, mask):
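# Low-variance "k3" estimator of the reverse KL: for samples y ~ q,
# E[exp(log p - log q) - 1 - (log p - log q)] = KL(q || p), and each term is non-negative.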
log_ratio = (log_p - log_q) * mask
kl = log_ratio.float().exp() - 1 - log_ratio
return kl
def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
"""
Computes element-wise mean and variance of the tensor across processes
"""
sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)
global_sum, count = sum_and_count
global_mean = global_sum / count
sum_var = torch.sum((xs - global_mean) ** 2)
dist.all_reduce(sum_var, dist.ReduceOp.SUM)
global_var = sum_var / count
return global_mean, global_var, count
def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
"""Whitens values"""
if distributed and dist.is_initialized():
mean, var, _ = get_global_statistics(xs)
else:
var, mean = torch.var_mean(xs)
whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
def significant(x: Number, ndigits=2) -> Number:
"""
Round `x` to `ndigits` significant digits
"""
if isinstance(x, torch.Tensor):
x = x.item()
if not isinstance(x, Number) or x == 0:
return x
return round(x, ndigits - int(math.floor(math.log10(abs(x)))))
class OptimizerName(str, Enum):
"""Supported optimizer names"""
ADAM: str = "adam"
ADAMW: str = "adamw"
ADAM_8BIT_BNB: str = "adam_8bit_bnb"
ADAMW_8BIT_BNB: str = "adamw_8bit_bnb"
SGD: str = "sgd"
def get_optimizer_class(name: OptimizerName):
"""
Returns the optimizer class with the given name
Args:
name (str): Name of the optimizer as found in `OptimizerName`
"""
if name == OptimizerName.ADAM:
return torch.optim.Adam
if name == OptimizerName.ADAMW:
return torch.optim.AdamW
if name == OptimizerName.SGD:
return torch.optim.SGD
supported_optimizers = [o.value for o in OptimizerName]
raise ValueError(
f"`{name}` is not a supported optimizer. "
f"Supported optimizers are: {supported_optimizers}"
)
class SchedulerName(str, Enum):
"""Supported scheduler names"""
COSINE_ANNEALING = "cosine_annealing"
LINEAR = "linear"
def get_scheduler_class(name: SchedulerName):
"""
Returns the scheduler class with the given name
"""
if name == SchedulerName.COSINE_ANNEALING:
return CosineAnnealingLR
if name == SchedulerName.LINEAR:
return LinearLR
supported_schedulers = [s.value for s in SchedulerName]
raise ValueError(
f"`{name}` is not a supported scheduler. "
f"Supported schedulers are: {supported_schedulers}"
)
================================================
FILE: rouge_metric.py
================================================
import string
import json
import os
import argparse
from rouge_score import rouge_scorer
from transformers import AutoTokenizer
default_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
# adapted the following from SQuAD v1.1 evaluation, without removing the articles.
def normalize_answer(s):
"""Lower text and remove punctuation, and extra whitespace."""
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_punc(lower(s)))
def exact_match(prediction, ground_truth, xlingual=False):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def rouge(prediction, ground_truth, xlingual=False):
scorer = default_rouge_scorer
scores = scorer.score(prediction=prediction, target=ground_truth)
return scores["rougeL"].fmeasure
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, xlingual=False):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth, xlingual=xlingual)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def compute_metrics(predictions, references, xlingual=False):
# assert len(predictions) == len(references), f"# of predictions {len(predictions)} doesn't match # of references {len(references)}."
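# Truncate to the shorter list so partially generated predictions can still be scored.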
min_length = min(len(predictions), len(references))
predictions = predictions[:min_length]
references = references[:min_length]
em, rougeL = 0, 0
for pred, gold in zip(predictions, references):
assert isinstance(gold, list)
em += metric_max_over_ground_truths(
exact_match, prediction=pred, ground_truths=gold, xlingual=xlingual
)
rougeL += metric_max_over_ground_truths(
rouge, prediction=pred, ground_truths=gold, xlingual=xlingual
)
em = 100.0 * em / len(references)
rougeL = 100.0 * rougeL / len(references)
metrics = {"exact_match": em, "rougeL": rougeL}
metrics = {k: round(v, 4) for k, v in metrics.items()}
return metrics
def compute_grouped_metrics(predictions, references, groups, xlingual=False):
assert len(predictions) == len(references) == len(groups)
examples_by_group = {}
for pred, gold, group in zip(predictions, references, groups):
if group not in examples_by_group:
examples_by_group[group] = []
examples_by_group[group].append((pred, gold))
results = {}
for group, group_examples in examples_by_group.items():
task_predictions, task_references = zip(*group_examples)
group_metrics = compute_metrics(task_predictions, task_references, xlingual=xlingual)
for metric, value in group_metrics.items():
results[f"{metric}_for_{group}"] = value
return results
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prediction_file", required=True,
help="Jsonl file with each line corresponding to a prediction. "
"Each json object should have an `id` and a `prediction` key.")
parser.add_argument(
"--reference_file", required=True,
help="Jsonl file with each line corresponding to a reference. "
"Each json object should have an `id` and a `references` key. "
"`task_id`, `task_category` and `task_track` are optional, which will be used to "
"compute the per-task performance, per-category performance and the performance for default (english) / xlingual Tracks.")
parser.add_argument(
"--output_file",
help="Directory to write the results to (a json file named after --model_name is created inside).")
parser.add_argument(
"--model_name",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
references = []
with open(args.reference_file) as fin:
for line in fin:
instance = json.loads(line)
if isinstance(instance["output"], list):
references.append(instance["output"])
else:
references.append([instance["output"]])
predictions = []
with open(args.prediction_file) as fin:
for line in fin:
prediction = json.loads(line)
predictions.append(prediction["text"])
predictions = predictions[:1000]
references = references[:len(predictions)]
results = compute_metrics(predictions, references, xlingual=False)
print(results)
if args.output_file:
os.makedirs(args.output_file, exist_ok=True)
with open(os.path.join(args.output_file, f"{args.model_name}.json"), "w") as fout:
json.dump(results, fout, indent=2)
================================================
FILE: scripts/gpt2/distillm/train_0.1B_1.5B.sh
================================================
#! /bin/bash
MASTER_ADDR=localhost
MASTER_PORT=${2-2012}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${3-16}
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
# model
BASE_PATH=${1-"/home/MiniLLM"}
CKPT_NAME="gpt2-base"
CKPT="${BASE_PATH}/results/gpt2/train/init/${CKPT_NAME}"
TEACHER_CKPT_NAME="xlarge-sft"
TEACHER_CKPT="${BASE_PATH}/results/gpt2/train/sft/gpt2-xlarge/"
# data
DATA_DIR="${BASE_PATH}/processed_data/dolly/full/gpt2/"
LM_DATA_DIR="${BASE_PATH}/processed_data/openwebtext/gpt2/512/10M/"
# hp
BATCH_SIZE=8
LR=0.0005
GRAD_ACC=1
EVAL_BATCH_SIZE=64
# length
MAX_LENGTH=512
# runtime
SAVE_PATH="${BASE_PATH}/results/gpt2/train/distill_0.1B_1.5B_final2"
# seed
SEED=10
OPTS=""
# model
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --model-path ${CKPT}"
OPTS+=" --teacher-model-path ${TEACHER_CKPT}"
OPTS+=" --ckpt-name ${CKPT_NAME}"
OPTS+=" --teacher-ckpt-name ${TEACHER_CKPT_NAME}"
OPTS+=" --teacher-model-fp16"
OPTS+=" --n-gpu ${GPUS_PER_NODE}"
# data
OPTS+=" --data-dir ${DATA_DIR}"
OPTS+=" --lm-data-dir ${LM_DATA_DIR}"
OPTS+=" --num-workers 4"
OPTS+=" --dev-num 1000"
# hp
OPTS+=" --lr ${LR}"
OPTS+=" --batch-size ${BATCH_SIZE}"
OPTS+=" --eval-batch-size ${EVAL_BATCH_SIZE}"
OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}"
OPTS+=" --warmup-iters 0"
OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 1e-2"
OPTS+=" --clip-grad 1.0"
OPTS+=" --epochs 20"
OPTS+=" --kd-ratio 1.0"
# length
OPTS+=" --max-length ${MAX_LENGTH}"
OPTS+=" --max-prompt-length 256"
# runtime
OPTS+=" --do-train"
OPTS+=" --do-valid"
OPTS+=" --eval-gen"
OPTS+=" --save-interval -1"
OPTS+=" --eval-interval -1"
OPTS+=" --log-interval 4"
OPTS+=" --mid-log-num -1"
OPTS+=" --save ${SAVE_PATH}"
# seed
OPTS+=" --seed ${SEED}"
# deepspeed
OPTS+=" --deepspeed"
OPTS+=" --deepspeed_config ${BASE_PATH}/configs/deepspeed/ds_config.json"
# type
OPTS+=" --type adaptive-sfkl"
# gen
OPTS+=" --do-sample"
OPTS+=" --top-k 0"
OPTS+=" --top-p 1.0"
OPTS+=" --temperature 1.0"
# distillm
OPTS+=" --student-gen"
OPTS+=" --gen-num-beams 1"
OPTS+=" --gen-top-p 1.0"
OPTS+=" --init-threshold 0.0"
OPTS+=" --loss-eps 0.1"
OPTS+=" --capacity 1000"
export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH=${BASE_PATH}
CMD="torchrun ${DISTRIBUTED_ARGS} ${BASE_PATH}/finetune.py ${OPTS} $@"
echo ${CMD}
echo "PYTHONPATH=${PYTHONPA
SYMBOL INDEX (283 symbols across 28 files)
FILE: arguments.py
function add_model_args (line 22) | def add_model_args(parser: argparse.ArgumentParser):
function add_runtime_args (line 43) | def add_runtime_args(parser: argparse.ArgumentParser):
function add_data_args (line 68) | def add_data_args(parser: argparse.ArgumentParser):
function add_hp_args (line 99) | def add_hp_args(parser: argparse.ArgumentParser):
function add_ppo_args (line 148) | def add_ppo_args(parser: argparse.ArgumentParser):
function add_minillm_args (line 163) | def add_minillm_args(parser: argparse.ArgumentParser):
function add_distillm_args (line 174) | def add_distillm_args(parser: argparse.ArgumentParser):
function add_gen_args (line 197) | def add_gen_args(parser: argparse.ArgumentParser):
function add_peft_args (line 211) | def add_peft_args(parser: argparse.ArgumentParser):
function get_args (line 225) | def get_args():
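
Each add_*_args helper registers one option group on a shared parser and returns it, and get_args chains them together. A minimal sketch of that pattern (the flag name below is illustrative, not one of the repo's actual options):

    import argparse

    def add_model_args(parser: argparse.ArgumentParser):
        group = parser.add_argument_group("model", "model configuration")
        group.add_argument("--model-path", type=str, help="illustrative flag")
        return parser

    def get_args():
        parser = argparse.ArgumentParser()
        # The repo chains ten such groups (runtime, data, hp, ppo, ...).
        parser = add_model_args(parser)
        return parser.parse_args()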
FILE: data_utils/distributed_indexed.py
function code (line 41) | def code(dtype):
function index_file_path (line 48) | def index_file_path(prefix_path):
function data_file_path (line 52) | def data_file_path(prefix_path):
class DistributedMMapIndexedDataset (line 56) | class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
class Index (line 57) | class Index(object):
method __init__ (line 59) | def __init__(self, path):
method __del__ (line 89) | def __del__(self):
method dtype (line 94) | def dtype(self):
method sizes (line 98) | def sizes(self):
method doc_idx (line 102) | def doc_idx(self):
method __getitem__ (line 105) | def __getitem__(self, i):
method __len__ (line 108) | def __len__(self):
method __init__ (line 111) | def __init__(self, path, name, rank_number, rank_total, cache = None):
method _probe_data_path (line 133) | def _probe_data_path(self, path, name, rank_total):
method __getstate__ (line 150) | def __getstate__(self):
method __setstate__ (line 153) | def __setstate__(self, state):
method _do_init (line 157) | def _do_init(self, path, name, cache, state):
method __del__ (line 171) | def __del__(self):
method __len__ (line 178) | def __len__(self):
method _next_file (line 181) | def _next_file(self):
method __relative_idx (line 188) | def __relative_idx(self, idx):
method __slice_item (line 192) | def __slice_item(self, start, stop):
method __getitem__ (line 199) | def __getitem__(self, idx):
method sizes (line 209) | def sizes(self):
method exists (line 212) | def exists(self, path):
FILE: data_utils/indexed_dataset.py
function __best_fitting_dtype (line 23) | def __best_fitting_dtype(vocab_size=None):
function get_available_dataset_impl (line 30) | def get_available_dataset_impl():
function infer_dataset_impl (line 34) | def infer_dataset_impl(path):
function make_builder (line 50) | def make_builder(out_file, impl, dtype):
function make_dataset (line 57) | def make_dataset(path, impl, skip_warmup=False):
function dataset_exists (line 74) | def dataset_exists(path, impl):
function read_longs (line 81) | def read_longs(f, n):
function write_longs (line 87) | def write_longs(f, a):
function code (line 104) | def code(dtype):
function index_file_path (line 111) | def index_file_path(prefix_path):
function data_file_path (line 115) | def data_file_path(prefix_path):
function create_doc_idx (line 119) | def create_doc_idx(sizes):
class IndexedDataset (line 127) | class IndexedDataset(torch.utils.data.Dataset):
method __init__ (line 131) | def __init__(self, path):
method read_index (line 137) | def read_index(self, path):
method read_data (line 155) | def read_data(self, path):
method check_index (line 158) | def check_index(self, i):
method __del__ (line 162) | def __del__(self):
method __getitem__ (line 167) | def __getitem__(self, idx):
method __len__ (line 191) | def __len__(self):
method num_tokens (line 194) | def num_tokens(self, index):
method size (line 197) | def size(self, index):
method exists (line 201) | def exists(path):
method supports_prefetch (line 207) | def supports_prefetch(self):
class IndexedCachedDataset (line 211) | class IndexedCachedDataset(IndexedDataset):
method __init__ (line 213) | def __init__(self, path):
method supports_prefetch (line 219) | def supports_prefetch(self):
method prefetch (line 222) | def prefetch(self, indices):
method __getitem__ (line 247) | def __getitem__(self, idx):
class IndexedDatasetBuilder (line 264) | class IndexedDatasetBuilder(object):
method __init__ (line 275) | def __init__(self, out_file, dtype=np.int32):
method add_item (line 284) | def add_item(self, tensor):
method end_document (line 291) | def end_document(self):
method merge_file_ (line 294) | def merge_file_(self, another_file):
method finalize (line 314) | def finalize(self, index_file):
function _warmup_mmap_file (line 329) | def _warmup_mmap_file(path):
class MMapIndexedDataset (line 335) | class MMapIndexedDataset(torch.utils.data.Dataset):
class Index (line 336) | class Index(object):
method writer (line 340) | def writer(cls, path, dtype):
method __init__ (line 385) | def __init__(self, path, skip_warmup=False):
method __del__ (line 422) | def __del__(self):
method dtype (line 427) | def dtype(self):
method sizes (line 431) | def sizes(self):
method doc_idx (line 435) | def doc_idx(self):
method __getitem__ (line 439) | def __getitem__(self, i):
method __len__ (line 442) | def __len__(self):
method __init__ (line 445) | def __init__(self, path, skip_warmup=False):
method __getstate__ (line 454) | def __getstate__(self):
method __setstate__ (line 457) | def __setstate__(self, state):
method _do_init (line 460) | def _do_init(self, path, skip_warmup):
method __del__ (line 472) | def __del__(self):
method __len__ (line 477) | def __len__(self):
method __getitem__ (line 481) | def __getitem__(self, idx):
method get (line 501) | def get(self, idx, offset=0, length=None):
method sizes (line 516) | def sizes(self):
method supports_prefetch (line 530) | def supports_prefetch(self):
method exists (line 534) | def exists(path):
class MMapIndexedDatasetBuilder (line 540) | class MMapIndexedDatasetBuilder(object):
method __init__ (line 541) | def __init__(self, out_file, dtype=np.int64):
method add_item (line 547) | def add_item(self, tensor):
method end_document (line 552) | def end_document(self):
method merge_file_ (line 555) | def merge_file_(self, another_file):
method finalize (line 567) | def finalize(self, index_file):
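
These datasets follow the fairseq/Megatron binary format this file derives from: a .bin data file plus a .idx index, read through numpy memory maps. A sketch of the small path helpers and the warmup step, consistent with that lineage:

    def index_file_path(prefix_path):
        return prefix_path + ".idx"

    def data_file_path(prefix_path):
        return prefix_path + ".bin"

    def _warmup_mmap_file(path):
        # Read the file once in large chunks so the OS page cache is warm
        # before random access through the memory map.
        with open(path, "rb") as stream:
            while stream.read(100 * 1024 * 1024):
                pass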
FILE: data_utils/lm_datasets.py
class LMTrainDataset (line 15) | class LMTrainDataset(Dataset):
method __init__ (line 16) | def __init__(self, args, tokenizer, path, split, num, ratio, rng_sampl...
method __len__ (line 40) | def __len__(self):
method __getitem__ (line 43) | def __getitem__(self, index):
method _get_lm (line 46) | def _get_lm(self, index):
method _process_lm (line 53) | def _process_lm(self, i, samp, model_data, no_model_data, gen_data):
method move_to_device (line 77) | def move_to_device(self, model_data, no_model_data, gen_data, device):
method collate (line 89) | def collate(self, samples):
FILE: data_utils/prompt_datasets.py
class PromptDataset (line 13) | class PromptDataset(Dataset):
method __init__ (line 14) | def __init__(self, args, tokenizer, split, data_path=None, num=-1):
method __len__ (line 50) | def __len__(self):
method load_data_json (line 53) | def load_data_json(self, data_path):
method load_data_txt (line 80) | def load_data_txt(self, data_path):
method verbalizer (line 93) | def verbalizer(self):
method __getitem__ (line 96) | def __getitem__(self, index: int):
method collate (line 114) | def collate(self, samples):
method move_to_device (line 141) | def move_to_device(self, model_batch, no_model_batch, device):
FILE: distillm/buffer.py
class ReplayBuffer (line 16) | class ReplayBuffer:
method __init__ (line 17) | def __init__(self, args):
method __len__ (line 28) | def __len__(self):
method sample (line 31) | def sample(self):
method move_to_device (line 54) | def move_to_device(self, model_data, no_model_data, device):
method move_to_memory (line 63) | def move_to_memory(self, model_data, no_model_data):
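
ReplayBuffer backs DistiLLM's off-policy side: student generations are parked in CPU memory and re-sampled later instead of being regenerated. A rough sketch of the interface, assuming a capacity-bounded list of (model_data, no_model_data) dicts; the repo's exact eviction policy and layout may differ:

    import random

    class ReplayBuffer:
        """Capacity-bounded CPU store of generated batches (illustrative)."""
        def __init__(self, capacity):
            self.buffer = []
            self.capacity = capacity

        def __len__(self):
            return len(self.buffer)

        def move_to_memory(self, model_data, no_model_data):
            # Detach to CPU so the buffer does not pin GPU memory.
            self.buffer.append((
                {k: v.detach().cpu() for k, v in model_data.items()},
                {k: v.detach().cpu() for k, v in no_model_data.items()}))
            if len(self.buffer) > self.capacity:
                self.buffer.pop(0)

        def sample(self):
            return random.choice(self.buffer)

        def move_to_device(self, model_data, no_model_data, device):
            return ({k: v.to(device) for k, v in model_data.items()},
                    {k: v.to(device) for k, v in no_model_data.items()})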
FILE: distillm/losses.py
function forward_kl (line 4) | def forward_kl(logits, teacher_logits, no_model_batch):
function reverse_kl (line 14) | def reverse_kl(logits, teacher_logits, no_model_batch):
function symmetric_kl (line 26) | def symmetric_kl(logits, teacher_logits, no_model_batch, lam=0.9):
function js_distance (line 32) | def js_distance(logits, teacher_logits, no_model_batch, lam=0.9):
function tv_distance (line 55) | def tv_distance(logits, teacher_logits, no_model_batch):
function skewed_forward_kl (line 66) | def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
function skewed_reverse_kl (line 80) | def skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=0.1):
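
forward_kl and reverse_kl are the two standard directions; the skewed variants measure the divergence against a teacher/student mixture, which is DistiLLM's stabilizer. A minimal sketch of skewed forward KL consistent with the signature above (the label == -100 masking convention is assumed from the HF ecosystem; the repo's numerical details may differ):

    import torch
    import torch.nn.functional as F

    def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        # Skewing: compare the teacher to a mixture instead of the raw
        # student, which keeps the KL finite where the student puts ~0 mass.
        mixed_probs = lam * teacher_probs + (1 - lam) * student_probs
        kl = (teacher_probs
              * (teacher_probs.clamp_min(1e-10).log()
                 - mixed_probs.clamp_min(1e-10).log())).sum(dim=-1)
        # Average over target tokens; label == -100 marks ignored positions.
        mask = (no_model_batch["label"] != -100).float()
        return (kl.view(-1) * mask.view(-1)).sum() / mask.view(-1).sum()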
FILE: distillm/sampler.py
class SampleGenerator (line 6) | class SampleGenerator():
method __init__ (line 7) | def __init__(self, args, tokenizer):
method run_sample (line 26) | def run_sample(self, model, gen_data):
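
SampleGenerator draws the student's own responses that feed the replay buffer. Given that the file imports GenerationConfig (see its preview below), a plausible sketch looks like this; the sampling settings and the args.max_length name are assumptions:

    import torch
    from transformers import GenerationConfig

    class SampleGenerator:
        """Draws student-generated responses for off-policy distillation."""
        def __init__(self, args, tokenizer):
            self.tokenizer = tokenizer
            self.config = GenerationConfig(
                do_sample=True,
                top_p=1.0,
                temperature=1.0,
                max_new_tokens=args.max_length,   # assumed arg name
                pad_token_id=tokenizer.pad_token_id)

        @torch.no_grad()
        def run_sample(self, model, gen_data):
            return model.generate(
                input_ids=gen_data["input_ids"],
                attention_mask=gen_data["attention_mask"],
                generation_config=self.config)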
FILE: evaluate.py
function setup_model (line 23) | def setup_model(args, ds_config, device):
function main (line 44) | def main():
FILE: evaluate_main.py
function prepare_dataset_main (line 22) | def prepare_dataset_main(args, tokenizer):
function run_model (line 29) | def run_model(args, tokenizer, model, dataset: PromptDataset, epoch, dev...
function evaluate_main (line 124) | def evaluate_main(args, tokenizer, model, dataset: PromptDataset, split,...
FILE: finetune.py
function get_teacher_model (line 50) | def get_teacher_model(args, device):
function get_optimizer (line 77) | def get_optimizer(args, model):
function get_learning_rate_scheduler (line 95) | def get_learning_rate_scheduler(args, optimizer):
function setup_model_and_optimizer (line 119) | def setup_model_and_optimizer(args, ds_config, device, set_optim=True):
function prepare_dataset (line 143) | def prepare_dataset(args, tokenizer):
function pt_loss (line 162) | def pt_loss(args, model, model_batch, no_model_batch):
function get_distil_loss (line 171) | def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, ...
function get_teacher_lm_loss (line 196) | def get_teacher_lm_loss(args, tokenizer, model, teacher_model, model_bat...
function finetune (line 238) | def finetune(args, tokenizer: AutoTokenizer, model: deepspeed.DeepSpeedE...
function evaluate (line 431) | def evaluate(args, tokenizer, model, dataset: LMTrainDataset, split, epo...
function main (line 538) | def main():
FILE: generate.py
function setup_model (line 28) | def setup_model(args, ds_config, device):
function prepare_dataset (line 48) | def prepare_dataset(args, tokenizer):
function generate (line 55) | def generate(args, tokenizer, model, dataset, device):
function main (line 123) | def main():
FILE: minillm/__init__.py
function train (line 10) | def train(
FILE: minillm/data_types.py
class PromptElement (line 7) | class PromptElement:
class PromptBatch (line 23) | class PromptBatch:
class PPORLElement (line 39) | class PPORLElement:
class PPORLBatch (line 80) | class PPORLBatch:
FILE: minillm/losses.py
class Loss (line 15) | class Loss():
method __init__ (line 16) | def __init__(self, args, trainer):
method _get_cumsum_rewards (line 20) | def _get_cumsum_rewards(self, rewards):
method _get_advantages_and_returns (line 27) | def _get_advantages_and_returns(
method _pg_loss (line 58) | def _pg_loss(
method _reg_loss (line 98) | def _reg_loss(self, query_ids, response_ids, mask, logits, inf_mask, s...
method get_input_batch (line 110) | def get_input_batch(self, ppo_batch: PPORLBatch, pt_batch):
method ppo_loss (line 122) | def ppo_loss(self, batch: PPORLBatch, logits):
method pt_loss (line 194) | def pt_loss(self, batch, logits):
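
_get_advantages_and_returns is the usual Generalized Advantage Estimation recursion over response positions, as in trlx-style PPO trainers. A standalone sketch (the gamma/lam defaults are illustrative; the repo reads them from args):

    import torch

    def get_advantages_and_returns(values, rewards, gamma=1.0, lam=0.95):
        # GAE, computed backwards over the response tokens.
        lastgaelam = 0.0
        advantages_reversed = []
        length = rewards.size(1)
        for t in reversed(range(length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values
        return advantages.detach(), returns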
FILE: minillm/model.py
class PPOModel (line 8) | class PPOModel(nn.Module):
method __init__ (line 9) | def __init__(self, args, device):
method forward (line 16) | def forward(self, **x):
method generate (line 20) | def generate(self, **x):
method set_force_gradient_checkpointing (line 23) | def set_force_gradient_checkpointing(self, value):
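
PPOModel is a thin nn.Module wrapper that forwards calls to the underlying HF student. A sketch of the shape (the repo builds the base model from args via utils.get_model; here it is passed in directly for self-containment):

    import torch.nn as nn

    class PPOModel(nn.Module):
        """Thin wrapper over the HF student model (sketch)."""
        def __init__(self, base_model):
            super().__init__()
            self.base_model = base_model

        def forward(self, **x):
            return self.base_model(**x)

        def generate(self, **x):
            return self.base_model.generate(**x)

        def set_force_gradient_checkpointing(self, value):
            # HF models expose paired toggles for activation checkpointing.
            if value:
                self.base_model.gradient_checkpointing_enable()
            else:
                self.base_model.gradient_checkpointing_disable()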
FILE: minillm/pipelines.py
class PPOPipeline (line 15) | class PPOPipeline():
method __init__ (line 16) | def __init__(self, args, tokenizer, split, ppo_data_path=None, fix_pro...
method __len__ (line 41) | def __len__(self):
method __getitem__ (line 44) | def __getitem__(self, index: int):
method collate (line 60) | def collate(self, samples):
method move_to_device (line 88) | def move_to_device(self, model_batch, no_model_batch, device):
method create_loader (line 96) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...
class LMPipeline (line 110) | class LMPipeline():
method __init__ (line 111) | def __init__(self, args, tokenizer, split, lm_data_path=None, num=-1):
method __len__ (line 126) | def __len__(self):
method __getitem__ (line 129) | def __getitem__(self, index):
method _get_lm (line 132) | def _get_lm(self, index):
method _process_lm (line 139) | def _process_lm(self, i, samp, model_data, no_model_data):
method move_to_device (line 157) | def move_to_device(self, model_batch, no_model_batch, device):
method collate (line 166) | def collate(self, samples):
method create_loader (line 189) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...
FILE: minillm/reward.py
class Reward (line 8) | class Reward():
method __init__ (line 9) | def __init__(self, args, tokenizer: AutoTokenizer, model: AutoModelFor...
method get_input_batch (line 16) | def get_input_batch(self, input_ids, gen_ids, output_pos=True):
method reward_fn (line 33) | def reward_fn(self, input_ids, gen_ids, inf_mask=None, output_pos=True):
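
In MiniLLM the reward is the teacher's log-probability of the student's generated tokens. A simplified sketch of reward_fn under that reading (model-parallel handling, inf_mask, and the output_pos flag are omitted):

    import torch

    class Reward:
        """Sketch: reward = teacher log-prob of each generated token."""
        def __init__(self, tokenizer, teacher_model):
            self.tokenizer = tokenizer
            self.model = teacher_model

        @torch.no_grad()
        def reward_fn(self, input_ids, gen_ids):
            full = torch.cat([input_ids, gen_ids], dim=-1)
            # logits[:, i] predicts token i+1, so drop the final position.
            logits = self.model(input_ids=full).logits[:, :-1].float()
            logprobs = torch.log_softmax(logits, dim=-1)
            start = input_ids.size(1) - 1  # first position predicting gen_ids[0]
            return logprobs[:, start:].gather(
                -1, gen_ids.unsqueeze(-1)).squeeze(-1)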
FILE: minillm/sampler.py
class PPOSampler (line 12) | class PPOSampler():
method __init__ (line 18) | def __init__(
method run_sample (line 39) | def run_sample(self, num_rollouts_per_device: int = 1024, iter_count: ...
FILE: minillm/storages.py
class BaseRolloutStore (line 17) | class BaseRolloutStore(Dataset):
method __init__ (line 18) | def __init__(self, capacity=-1):
method push (line 23) | def push(self, exps: Iterable[Any]):
method __getitem__ (line 29) | def __getitem__(self, index: int) -> PPORLElement:
method __len__ (line 32) | def __len__(self) -> int:
method create_loader (line 36) | def create_loader(
method broadcast (line 53) | def broadcast(self, batch, src=0, group=None):
method move_to_device (line 57) | def move_to_device(self, batch, device):
class PPORolloutStorage (line 61) | class PPORolloutStorage(BaseRolloutStore):
method __init__ (line 66) | def __init__(self, pad_token_id, seed):
method push (line 74) | def push(self, exps: Iterable[PPORLElement]):
method save (line 77) | def save(self, path):
method load (line 85) | def load(self, path):
method clear_history (line 89) | def clear_history(self):
method export_history (line 92) | def export_history(self, location: str):
method __getitem__ (line 104) | def __getitem__(self, index: int) -> PPORLElement:
method __len__ (line 107) | def __len__(self) -> int:
method collate (line 110) | def collate(self, elems: Iterable[PPORLElement]):
method create_loader (line 170) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...
method broadcast (line 177) | def broadcast(self, batch: PPORLBatch, src=0, group=None):
method move_to_device (line 181) | def move_to_device(self, batch: PPORLBatch, device):
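
PPORolloutStorage.collate has to batch variable-length rollouts; the common convention (used by trlx, from which these types descend) is to left-pad queries and right-pad responses. A sketch of that trick, assuming PPORLElement exposes query_tensor/response_tensor fields:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_queries_responses(elems, pad_token_id):
        # Left-pad queries (flip, right-pad, flip back) so the prompt ends
        # flush at the generation boundary; right-pad responses normally.
        queries = pad_sequence(
            [e.query_tensor.flip(0) for e in elems],
            batch_first=True, padding_value=pad_token_id).flip(1)
        responses = pad_sequence(
            [e.response_tensor for e in elems],
            batch_first=True, padding_value=pad_token_id)
        return queries, responses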
FILE: minillm/trainer.py
class PPOTrainer (line 43) | class PPOTrainer():
method __init__ (line 48) | def __init__(self, args, tokenizer: AutoTokenizer, reward_fn, ds_config):
method set_teacher_model (line 94) | def set_teacher_model(self, model):
method set_sampler (line 97) | def set_sampler(self, sampler):
method setup_optimizer (line 100) | def setup_optimizer(self):
method setup_scheduler (line 114) | def setup_scheduler(self):
method setup_ds (line 128) | def setup_ds(self, model, optimizer=None, scheduler=None):
method add_eval_pipeline (line 139) | def add_eval_pipeline(self, eval_pipeline: PPOPipeline):
method add_lm_pipeline (line 143) | def add_lm_pipeline(self, lm_pipeline: LMPipeline, eval_lm_pipeline: L...
method get_model_inputs (line 147) | def get_model_inputs(
method get_mask (line 170) | def get_mask(self, tokens):
method forward_model (line 176) | def forward_model(self, batch):
method compute_logits_and_log_probs (line 184) | def compute_logits_and_log_probs(self, query_ids, response_ids, inf_ma...
method train (line 220) | def train(self):
method post_backward_callback (line 372) | def post_backward_callback(self):
method post_epoch_callback (line 375) | def post_epoch_callback(self, epoch):
method prepare_learning (line 382) | def prepare_learning(self):
method evaluate (line 406) | def evaluate(self):
method evaluate_ppo (line 426) | def evaluate_ppo(self): # noqa: C901
method evaluate_pt (line 522) | def evaluate_pt(self):
method save (line 550) | def save(self, directory: Optional[str] = None):
method save_evals (line 565) | def save_evals(self, preds, results, response_texts, directory: Option...
method push_to_store (line 579) | def push_to_store(self, data):
method generate (line 582) | def generate(self, input_ids, attention_mask=None, mode="base", teache...
FILE: minillm/utils.py
function get_entropy (line 19) | def get_entropy(gen_logits, inf_mask, mask, model_parallel=False):
function get_log_probs (line 32) | def get_log_probs(logits, ids, mask, inf_mask=None, model_parallel=False):
function get_x_entropy (line 48) | def get_x_entropy(logits_1, logits_2, inf_mask, mask, model_parallel=Fal...
function get_rev_kl (line 61) | def get_rev_kl(log_p, log_q, mask):
function get_global_statistics (line 67) | def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
function whiten (line 82) | def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch...
function significant (line 95) | def significant(x: Number, ndigits=2) -> Number:
class OptimizerName (line 108) | class OptimizerName(str, Enum):
function get_optimizer_class (line 118) | def get_optimizer_class(name: OptimizerName):
class SchedulerName (line 138) | class SchedulerName(str, Enum):
function get_scheduler_class (line 145) | def get_scheduler_class(name: SchedulerName):
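
whiten normalizes advantages to zero mean and unit variance before the PG loss, with get_global_statistics supplying mean/variance aggregated across ranks in the distributed case. A single-process sketch:

    import torch

    def whiten(xs, shift_mean=True):
        # Single-process variant; the distributed path aggregates mean and
        # variance across ranks via get_global_statistics.
        mean, var = torch.mean(xs), torch.var(xs)
        whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
        if not shift_mean:
            whitened += mean
        return whitened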
FILE: rouge_metric.py
function normalize_answer (line 12) | def normalize_answer(s):
function exact_match (line 28) | def exact_match(prediction, ground_truth, xlingual=False):
function rouge (line 32) | def rouge(prediction, ground_truth, xlingual=False):
function metric_max_over_ground_truths (line 38) | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, ...
function compute_metrics (line 46) | def compute_metrics(predictions, references, xlingual=False):
function compute_grouped_metrics (line 69) | def compute_grouped_metrics(predictions, references, groups, xlingual=Fa...
function parse_args (line 87) | def parse_args():
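
The evaluation metric scores each prediction against all of its references and keeps the best match. A minimal sketch of rouge and metric_max_over_ground_truths using the rouge_score package the file imports (the xlingual branch, which swaps in a multilingual tokenizer, is omitted):

    from rouge_score import rouge_scorer

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

    def rouge(prediction, ground_truth):
        # rouge_score takes (target, prediction); report the F-measure.
        return scorer.score(ground_truth, prediction)["rougeL"].fmeasure

    def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
        # Score against every reference and keep the best.
        return max(metric_fn(prediction, gt) for gt in ground_truths)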
FILE: tools/convert_mp.py
function main (line 24) | def main():
FILE: tools/process_data_dolly.py
class Encoder (line 15) | class Encoder(object):
method __init__ (line 16) | def __init__(self, args):
method initializer (line 19) | def initializer(self):
method encode (line 22) | def encode(self, line):
function main (line 65) | def main():
FILE: tools/process_data_pretrain.py
class Encoder (line 14) | class Encoder(object):
method __init__ (line 15) | def __init__(self, args):
method initializer (line 18) | def initializer(self):
method encode (line 21) | def encode(self, line):
function main (line 28) | def main():
FILE: train_minillm.py
function get_teacher_model (line 19) | def get_teacher_model(args, device):
function main (line 44) | def main():
FILE: utils.py
function print_args (line 24) | def print_args(args):
function save_rank (line 33) | def save_rank(log_str, save_path, rank=0):
function print_rank (line 39) | def print_rank(*args, rank=0, **kwargs):
function all_gather (line 45) | def all_gather(t, dim=0, world_size=None, group=None, op="cat"):
function set_random_seed (line 58) | def set_random_seed(seed, mp=False):
function init_distributed (line 69) | def init_distributed(args):
function init_distributed_ds (line 86) | def init_distributed_ds(args):
function initialize (line 103) | def initialize(args):
function get_model (line 120) | def get_model(args, device):
function get_optimizer_params (line 176) | def get_optimizer_params(args, model: nn.Module):
function get_optimizer_params_peft (line 190) | def get_optimizer_params_peft(args, model: nn.Module):
function get_tokenizer (line 200) | def get_tokenizer(args):
function load_parallel (line 212) | def load_parallel(model, load_dir):
function save_parallel (line 222) | def save_parallel(model, save_dir):
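
all_gather collects a tensor from every rank and merges the pieces, either concatenated along an existing dimension or stacked into a new one. A sketch of that helper under the listed signature:

    import torch
    import torch.distributed as dist

    def all_gather(t, dim=0, world_size=None, group=None, op="cat"):
        # Collect a copy of `t` from every rank, then combine.
        world_size = world_size or dist.get_world_size(group)
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t, group=group)
        if op == "cat":
            return torch.cat(gathered, dim=dim)
        return torch.stack(gathered, dim=dim)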
Condensed preview — 130 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 3813,
"preview": "# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks\n# Edit at https://www.toptal.com/de"
},
{
"path": "README.md",
"chars": 9339,
"preview": "# DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)\n\n<a href=\"https://arxiv.org/abs/2402."
},
{
"path": "arguments.py",
"chars": 15168,
"preview": "# coding=utf-8\n# Copyright 2020 The OpenBMB team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
},
{
"path": "configs/deepspeed/ds_config.json",
"chars": 377,
"preview": "{\n \"train_micro_batch_size_per_gpu\": 1,\n \"gradient_accumulation_steps\": 1,\n \"zero_optimization\": {\n \"sta"
},
{
"path": "configs/deepspeed/ds_config_fp32.json",
"chars": 258,
"preview": "{\n \"train_micro_batch_size_per_gpu\": 1,\n \"gradient_accumulation_steps\": 1,\n \"zero_optimization\": {\n \"sta"
},
{
"path": "configs/deepspeed/ds_config_zero2.json",
"chars": 659,
"preview": "{\n \"train_micro_batch_size_per_gpu\": 1,\n \"gradient_accumulation_steps\": 1,\n \"zero_optimization\": {\n \"sta"
},
{
"path": "configs/deepspeed/ds_config_zero2_offload.json",
"chars": 700,
"preview": "{\n \"train_micro_batch_size_per_gpu\": 1,\n \"gradient_accumulation_steps\": 1,\n \"zero_optimization\": {\n \"sta"
},
{
"path": "configs/hostfiles/node_0_1",
"chars": 29,
"preview": "node-0 slots=8\nnode-1 slots=8"
},
{
"path": "configs/hostfiles/node_0_1_2_3",
"chars": 59,
"preview": "node-0 slots=8\nnode-1 slots=8\nnode-2 slots=8\nnode-3 slots=8"
},
{
"path": "configs/hostfiles/node_1_2",
"chars": 29,
"preview": "node-1 slots=8\nnode-2 slots=8"
},
{
"path": "configs/hostfiles/node_2_3",
"chars": 29,
"preview": "node-2 slots=8\nnode-3 slots=8"
},
{
"path": "data_utils/distributed_indexed.py",
"chars": 7176,
"preview": "# coding=utf-8\n# Copyright 2020 The OpenBMB team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
},
{
"path": "data_utils/indexed_dataset.py",
"chars": 18552,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
},
{
"path": "data_utils/lm_datasets.py",
"chars": 4283,
"preview": "import random\nimport torch\nimport os\nimport json\nimport pickle\nimport numpy as np\nfrom torch.utils.data import Dataset\nf"
},
{
"path": "data_utils/prompt_datasets.py",
"chars": 5882,
"preview": "import random\nimport torch\nimport os\nfrom torch.utils.data import Dataset\nfrom .distributed_indexed import DistributedMM"
},
{
"path": "distillm/__init__.py",
"chars": 210,
"preview": "from .losses import forward_kl, reverse_kl, symmetric_kl, js_distance, tv_distance\nfrom .losses import skewed_forward_kl"
},
{
"path": "distillm/buffer.py",
"chars": 3053,
"preview": "import random\nimport torch\nimport os\nimport json\nimport pickle\nimport numpy as np\nfrom torch.utils.data import Dataset\n\n"
},
{
"path": "distillm/losses.py",
"chars": 4906,
"preview": "import torch\nimport torch.nn.functional as F\n\ndef forward_kl(logits, teacher_logits, no_model_batch):\n teacher_probs "
},
{
"path": "distillm/sampler.py",
"chars": 2812,
"preview": "import torch\nimport os\nfrom transformers import GenerationConfig\n\n\nclass SampleGenerator():\n def __init__(self, args,"
},
{
"path": "evaluate.py",
"chars": 2367,
"preview": "import time\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport deepspeed\n\nimport json\n\nfrom arguments impor"
},
{
"path": "evaluate_main.py",
"chars": 6661,
"preview": "from data_utils.prompt_datasets import PromptDataset\nfrom transformers import GenerationConfig\nimport os\nimport nltk\nnlt"
},
{
"path": "finetune.py",
"chars": 24409,
"preview": "import time\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch."
},
{
"path": "generate.py",
"chars": 5155,
"preview": "import time\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader, Distribute"
},
{
"path": "install.sh",
"chars": 420,
"preview": "export NCCL_DEBUG=\"\"\n# conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -"
},
{
"path": "minillm/__init__.py",
"chars": 1588,
"preview": "from deepspeed import DeepSpeedConfig\nfrom typing import Optional\n\n# from trlx.utils.loading import get_orchestrator, ge"
},
{
"path": "minillm/data_types.py",
"chars": 3589,
"preview": "from dataclasses import dataclass\nfrom typing import Iterable\nfrom torchtyping import TensorType\n\n\n@dataclass\nclass Prom"
},
{
"path": "minillm/losses.py",
"chars": 9137,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nfrom torchtyping i"
},
{
"path": "minillm/model.py",
"chars": 722,
"preview": "import torch.nn as nn\nfrom transformers import (\n AutoConfig,)\n\nfrom utils import get_model\n\n\nclass PPOModel(nn.Modul"
},
{
"path": "minillm/pipelines.py",
"chars": 8225,
"preview": "import os\nimport json\nimport torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader, Distributed"
},
{
"path": "minillm/reward.py",
"chars": 2872,
"preview": "import torch\nfrom transformers import (\n AutoModelForCausalLM,\n AutoTokenizer,\n mpu)\n\n\nclass Reward():\n def "
},
{
"path": "minillm/sampler.py",
"chars": 8436,
"preview": "import torch\nimport os\n\nfrom .data_types import PromptBatch, PPORLElement\nfrom .pipelines import PPOPipeline\nfrom .train"
},
{
"path": "minillm/storages.py",
"chars": 5807,
"preview": "import json\nimport os\nimport time\nfrom abc import abstractmethod\nfrom typing import Any, Callable, Iterable\n\nimport torc"
},
{
"path": "minillm/trainer.py",
"chars": 26448,
"preview": "import json\nimport os\nimport deepspeed\nfrom time import time\nfrom typing import Optional, Tuple\nfrom collections import "
},
{
"path": "minillm/utils.py",
"chars": 4793,
"preview": "import math\nfrom enum import Enum\nfrom numbers import Number\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.func"
},
{
"path": "rouge_metric.py",
"chars": 4902,
"preview": "import string\nimport json\nimport os\nimport argparse\nfrom rouge_score import rouge_scorer\nfrom transformers import AutoTo"
},
{
"path": "scripts/gpt2/distillm/train_0.1B_1.5B.sh",
"chars": 2554,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/distillm/train_0.3B_1.5B.sh",
"chars": 2593,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/distillm/train_0.7B_1.5B.sh",
"chars": 2592,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/eval/eval_main_dolly.sh",
"chars": 1639,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/gpt2/eval/eval_main_self_inst.sh",
"chars": 1647,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/gpt2/eval/eval_main_sinst.sh",
"chars": 1669,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/gpt2/eval/eval_main_uinst.sh",
"chars": 1672,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/gpt2/eval/eval_main_vicuna.sh",
"chars": 1641,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/gpt2/eval/run_eval.sh",
"chars": 802,
"preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\nfor seed in 10 20 30 40 50\ndo\n CUDA_VISIBLE_DEVICES=${DEVICE} ba"
},
{
"path": "scripts/gpt2/gkd/gkd_base.sh",
"chars": 2343,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/gkd/gkd_large.sh",
"chars": 2349,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/gkd/gkd_medium.sh",
"chars": 2351,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/imitkd/imitkd_base.sh",
"chars": 2424,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/imitkd/imitkd_large.sh",
"chars": 2426,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/imitkd/imitkd_medium.sh",
"chars": 2428,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/init/init_base.sh",
"chars": 2056,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/init/init_large.sh",
"chars": 2062,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/init/init_medium.sh",
"chars": 2064,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/kd/kd_base.sh",
"chars": 2300,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/kd/kd_large.sh",
"chars": 2320,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/kd/kd_medium.sh",
"chars": 2323,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/minillm/train_base_xl.sh",
"chars": 2516,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/minillm/train_large_xl.sh",
"chars": 2517,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/minillm/train_medium_xl.sh",
"chars": 2519,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/seqkd/seqkd_base.sh",
"chars": 2359,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/seqkd/seqkd_large.sh",
"chars": 2325,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/seqkd/seqkd_medium.sh",
"chars": 2328,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/sft/sft_base.sh",
"chars": 2057,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/sft/sft_large.sh",
"chars": 2063,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/sft/sft_medium.sh",
"chars": 2065,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/sft/sft_xlarge.sh",
"chars": 2112,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/tools/generate_data_seqkd.sh",
"chars": 1568,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/gpt2/tools/process_data_dolly.sh",
"chars": 824,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
},
{
"path": "scripts/gpt2/tools/process_data_pretrain.sh",
"chars": 412,
"preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n --data"
},
{
"path": "scripts/gpt2/tools/process_pseudo_data_seqkd.sh",
"chars": 740,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
},
{
"path": "scripts/openllama2/distillm/train_3B_7B_teacher_lora.sh",
"chars": 2970,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/eval/eval_main_dolly_lora.sh",
"chars": 1980,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/openllama2/eval/eval_main_self_inst_lora.sh",
"chars": 1988,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/openllama2/eval/eval_main_sinst_lora.sh",
"chars": 2010,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/openllama2/eval/eval_main_uinst_lora.sh",
"chars": 2013,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/openllama2/eval/eval_main_vicuna_lora.sh",
"chars": 1982,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/openllama2/eval/run_eval.sh",
"chars": 935,
"preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\n# dolly eval\nfor seed in 10 20 30 40 50\ndo\n CUDA_VISIBLE_DEVICES"
},
{
"path": "scripts/openllama2/gkd/gkd_3B_7B_teacher_lora.sh",
"chars": 2937,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/imitkd/imitkd_3B_7B_teacher_lora.sh",
"chars": 2943,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/init/sft_3B_lora.sh",
"chars": 2164,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/kd/kd_3B_7B_teacher_lora.sh",
"chars": 2657,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/minillm/train_3B_7B_lora.sh",
"chars": 3026,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/seqkd/seqkd_3B_7B_teacher_lora.sh",
"chars": 2658,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/sft/sft_3B_lora.sh",
"chars": 2149,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/sft/sft_7B_lora.sh",
"chars": 2149,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/tools/generate_data_seqkd.sh",
"chars": 1804,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/openllama2/tools/process_data_dolly.sh",
"chars": 842,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
},
{
"path": "scripts/openllama2/tools/process_data_pretrain.sh",
"chars": 421,
"preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n --data"
},
{
"path": "scripts/openllama2/tools/process_pseudo_data_seqkd.sh",
"chars": 789,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
},
{
"path": "scripts/opt/distillm/train_0.1B_2.7B.sh",
"chars": 2627,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/distillm/train_0.3B_2.7B.sh",
"chars": 2627,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/distillm/train_1.3B_2.7B.sh",
"chars": 2627,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/eval/eval_main_dolly.sh",
"chars": 1717,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/opt/eval/eval_main_self_inst.sh",
"chars": 1725,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/opt/eval/eval_main_sinst.sh",
"chars": 1751,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/opt/eval/eval_main_uinst.sh",
"chars": 1750,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/opt/eval/eval_main_vicuna.sh",
"chars": 1719,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
},
{
"path": "scripts/opt/eval/run_eval.sh",
"chars": 801,
"preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\n# dolly eval\nfor seed in $SEED\ndo\n CUDA_VISIBLE_DEVICES=${DEVICE"
},
{
"path": "scripts/opt/gkd/gkd_0.1B_2.7B.sh",
"chars": 2476,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/gkd/gkd_0.3B_2.7B.sh",
"chars": 2477,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/gkd/gkd_1.3B_2.7B.sh",
"chars": 2477,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/imitkd/imitkd_0.1B_2.7B.sh",
"chars": 2479,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/imitkd/imitkd_0.3B_2.7B.sh",
"chars": 2480,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/imitkd/imitkd_1.3B_2.7B.sh",
"chars": 2480,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/init/init_0.1B.sh",
"chars": 2101,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/init/init_0.3B.sh",
"chars": 2065,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/init/init_1.3B.sh",
"chars": 2100,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/kd/kd_0.1B_2.7B.sh",
"chars": 2403,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/kd/kd_0.3B_2.7B.sh",
"chars": 2403,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/kd/kd_1.3B_2.7B.sh",
"chars": 2403,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/minillm/train_0.1B_2.7B.sh",
"chars": 2528,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/minillm/train_0.3B_2.7B.sh",
"chars": 2536,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/minillm/train_1.3B_2.7B.sh",
"chars": 2536,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/seqkd/seqkd_0.1B_2.7B.sh",
"chars": 2413,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/seqkd/seqkd_0.3B_2.7B.sh",
"chars": 2413,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/seqkd/seqkd_1.3B_2.7B.sh",
"chars": 2413,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/sft/sft_0.1B.sh",
"chars": 2099,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/sft/sft_0.3B.sh",
"chars": 2101,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/sft/sft_1.3B.sh",
"chars": 2143,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/sft/sft_2.7B.sh",
"chars": 2100,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/tools/generate_data_seqkd.sh",
"chars": 1595,
"preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
},
{
"path": "scripts/opt/tools/process_data_dolly.sh",
"chars": 818,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
},
{
"path": "scripts/opt/tools/process_data_pretrain.sh",
"chars": 408,
"preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n --data"
},
{
"path": "scripts/opt/tools/process_pseudo_data_seqkd.sh",
"chars": 743,
"preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
},
{
"path": "tools/convert_mp.py",
"chars": 4251,
"preview": "#coding:utf-8\nimport torch\nimport argparse\nimport os\nfrom transformers import AutoModelForCausalLM\nfrom transformers imp"
},
{
"path": "tools/get_openwebtext.py",
"chars": 343,
"preview": "import datasets\nimport os\nimport re\n\ndataset = datasets.load_dataset('openwebtext', split='train')\n\nos.makedirs(\"data/op"
},
{
"path": "tools/process_data_dolly.py",
"chars": 6422,
"preview": "import multiprocessing\nimport os\nimport time\nimport torch\nimport json\nimport sys\nfrom numerize.numerize import numerize\n"
},
{
"path": "tools/process_data_pretrain.py",
"chars": 3784,
"preview": "import multiprocessing\nimport os\nimport time\nimport torch\nimport sys\nfrom numerize.numerize import numerize\nimport numpy"
},
{
"path": "train_minillm.py",
"chars": 2608,
"preview": "import torch\nimport os\nimport json\nimport torch.distributed as dist\nfrom accelerate import init_empty_weights\n\nfrom tran"
},
{
"path": "utils.py",
"chars": 8111,
"preview": "from typing import Dict\nimport numpy as np\nimport os\nimport time\nimport torch.distributed as dist\nfrom torch.distributed"
}
]