Full Code of jongwooko/distillm for AI

master d47e77ff9d27 cached
130 files
390.6 KB
117.1k tokens
283 symbols
1 requests
Download .txt
Showing preview only (425K chars total). Download the full file or copy to clipboard to get everything.
Repository: jongwooko/distillm
Branch: master
Commit: d47e77ff9d27
Files: 130
Total size: 390.6 KB

Directory structure:
gitextract_j6ug111_/

├── .gitignore
├── README.md
├── arguments.py
├── configs/
│   ├── deepspeed/
│   │   ├── ds_config.json
│   │   ├── ds_config_fp32.json
│   │   ├── ds_config_zero2.json
│   │   └── ds_config_zero2_offload.json
│   └── hostfiles/
│       ├── node_0_1
│       ├── node_0_1_2_3
│       ├── node_1_2
│       └── node_2_3
├── data_utils/
│   ├── distributed_indexed.py
│   ├── indexed_dataset.py
│   ├── lm_datasets.py
│   └── prompt_datasets.py
├── distillm/
│   ├── __init__.py
│   ├── buffer.py
│   ├── losses.py
│   └── sampler.py
├── evaluate.py
├── evaluate_main.py
├── finetune.py
├── generate.py
├── install.sh
├── minillm/
│   ├── __init__.py
│   ├── data_types.py
│   ├── losses.py
│   ├── model.py
│   ├── pipelines.py
│   ├── reward.py
│   ├── sampler.py
│   ├── storages.py
│   ├── trainer.py
│   └── utils.py
├── rouge_metric.py
├── scripts/
│   ├── gpt2/
│   │   ├── distillm/
│   │   │   ├── train_0.1B_1.5B.sh
│   │   │   ├── train_0.3B_1.5B.sh
│   │   │   └── train_0.7B_1.5B.sh
│   │   ├── eval/
│   │   │   ├── eval_main_dolly.sh
│   │   │   ├── eval_main_self_inst.sh
│   │   │   ├── eval_main_sinst.sh
│   │   │   ├── eval_main_uinst.sh
│   │   │   ├── eval_main_vicuna.sh
│   │   │   └── run_eval.sh
│   │   ├── gkd/
│   │   │   ├── gkd_base.sh
│   │   │   ├── gkd_large.sh
│   │   │   └── gkd_medium.sh
│   │   ├── imitkd/
│   │   │   ├── imitkd_base.sh
│   │   │   ├── imitkd_large.sh
│   │   │   └── imitkd_medium.sh
│   │   ├── init/
│   │   │   ├── init_base.sh
│   │   │   ├── init_large.sh
│   │   │   └── init_medium.sh
│   │   ├── kd/
│   │   │   ├── kd_base.sh
│   │   │   ├── kd_large.sh
│   │   │   └── kd_medium.sh
│   │   ├── minillm/
│   │   │   ├── train_base_xl.sh
│   │   │   ├── train_large_xl.sh
│   │   │   └── train_medium_xl.sh
│   │   ├── seqkd/
│   │   │   ├── seqkd_base.sh
│   │   │   ├── seqkd_large.sh
│   │   │   └── seqkd_medium.sh
│   │   ├── sft/
│   │   │   ├── sft_base.sh
│   │   │   ├── sft_large.sh
│   │   │   ├── sft_medium.sh
│   │   │   └── sft_xlarge.sh
│   │   └── tools/
│   │       ├── generate_data_seqkd.sh
│   │       ├── process_data_dolly.sh
│   │       ├── process_data_pretrain.sh
│   │       └── process_pseudo_data_seqkd.sh
│   ├── openllama2/
│   │   ├── distillm/
│   │   │   └── train_3B_7B_teacher_lora.sh
│   │   ├── eval/
│   │   │   ├── eval_main_dolly_lora.sh
│   │   │   ├── eval_main_self_inst_lora.sh
│   │   │   ├── eval_main_sinst_lora.sh
│   │   │   ├── eval_main_uinst_lora.sh
│   │   │   ├── eval_main_vicuna_lora.sh
│   │   │   └── run_eval.sh
│   │   ├── gkd/
│   │   │   └── gkd_3B_7B_teacher_lora.sh
│   │   ├── imitkd/
│   │   │   └── imitkd_3B_7B_teacher_lora.sh
│   │   ├── init/
│   │   │   └── sft_3B_lora.sh
│   │   ├── kd/
│   │   │   └── kd_3B_7B_teacher_lora.sh
│   │   ├── minillm/
│   │   │   └── train_3B_7B_lora.sh
│   │   ├── seqkd/
│   │   │   └── seqkd_3B_7B_teacher_lora.sh
│   │   ├── sft/
│   │   │   ├── sft_3B_lora.sh
│   │   │   └── sft_7B_lora.sh
│   │   └── tools/
│   │       ├── generate_data_seqkd.sh
│   │       ├── process_data_dolly.sh
│   │       ├── process_data_pretrain.sh
│   │       └── process_pseudo_data_seqkd.sh
│   └── opt/
│       ├── distillm/
│       │   ├── train_0.1B_2.7B.sh
│       │   ├── train_0.3B_2.7B.sh
│       │   └── train_1.3B_2.7B.sh
│       ├── eval/
│       │   ├── eval_main_dolly.sh
│       │   ├── eval_main_self_inst.sh
│       │   ├── eval_main_sinst.sh
│       │   ├── eval_main_uinst.sh
│       │   ├── eval_main_vicuna.sh
│       │   └── run_eval.sh
│       ├── gkd/
│       │   ├── gkd_0.1B_2.7B.sh
│       │   ├── gkd_0.3B_2.7B.sh
│       │   └── gkd_1.3B_2.7B.sh
│       ├── imitkd/
│       │   ├── imitkd_0.1B_2.7B.sh
│       │   ├── imitkd_0.3B_2.7B.sh
│       │   └── imitkd_1.3B_2.7B.sh
│       ├── init/
│       │   ├── init_0.1B.sh
│       │   ├── init_0.3B.sh
│       │   └── init_1.3B.sh
│       ├── kd/
│       │   ├── kd_0.1B_2.7B.sh
│       │   ├── kd_0.3B_2.7B.sh
│       │   └── kd_1.3B_2.7B.sh
│       ├── minillm/
│       │   ├── train_0.1B_2.7B.sh
│       │   ├── train_0.3B_2.7B.sh
│       │   └── train_1.3B_2.7B.sh
│       ├── seqkd/
│       │   ├── seqkd_0.1B_2.7B.sh
│       │   ├── seqkd_0.3B_2.7B.sh
│       │   └── seqkd_1.3B_2.7B.sh
│       ├── sft/
│       │   ├── sft_0.1B.sh
│       │   ├── sft_0.3B.sh
│       │   ├── sft_1.3B.sh
│       │   └── sft_2.7B.sh
│       └── tools/
│           ├── generate_data_seqkd.sh
│           ├── process_data_dolly.sh
│           ├── process_data_pretrain.sh
│           └── process_pseudo_data_seqkd.sh
├── tools/
│   ├── convert_mp.py
│   ├── get_openwebtext.py
│   ├── process_data_dolly.py
│   └── process_data_pretrain.py
├── train_minillm.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
#   git rm -r .ipynb_checkpoints/

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook

# IPython

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks

./results/
./data/
./checkpoints/
./processed_data/
*.tar.gz


================================================
FILE: README.md
================================================
# DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)

<a href="https://arxiv.org/abs/2402.03898"><img src="https://img.shields.io/badge/Paper-arXiv:2402.03898-Green"></a>
<a href=#bibtex><img src="https://img.shields.io/badge/Paper-BibTex-yellow"></a>

Official PyTorch implementation of **DistiLLM**, as presented in our paper: \
\
**DistiLLM: Towards Streamlined Distillation for Large Language Models** \
*[Jongwoo Ko](https://sites.google.com/view/jongwooko), [Sungnyun Kim](https://sungnyunkim.notion.site/Sungnyun-Kim-4770a0182c47469ebdcd357cde97bd32), Tianyi Chen, Se-Young Yun* \
KAIST AI and Microsoft

## 🚀 Updates
- [x] (25.03.11) DistiLLM-2 paper is out! The preliminary code will be available in this repo, and final code will be available in [here](https://github.com/jongwooko/distillm-2), soon.
- [x] (24.08.12) Remove the dependency on the local transformers, which are outdated. You can work with various types of recent LLMs!
- [x] (24.05.01) Our paper has been accepted in **ICML 2024**. We are open to receiving any discussions and will reflect them in the camera-ready version. Looking forward to seeing you in Vienna!
- [x] (24.03.13) Release [**LoRA checkpoints for OpenLLaMa2-3B**](https://drive.google.com/drive/folders/1Yun1aNpn-mz2h-IVH_VdJ1Jhzm0K55Bo?usp=sharing)

## Environment
```bash
bash install.sh
```

Our code is based on [this commit](https://github.com/huggingface/transformers/commit/85fde09c97213bf7e8625f83096bb2a9e183f987) of HuggingFace Transformers **by following MiniLLM**.

## Data
### Resources
+ The training/evaluation instruction-response data before processing can be downloaded from this [link](https://conversationhub.blob.core.windows.net/beit-share-public/MiniLLM/data.tar?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D).
+ The plain-text corpus $\mathcal{D}_\text{PT}$ can be downloaded from the HuggingFace datasets [repository](https://huggingface.co/datasets/openwebtext).


### Data Processing
Get plain-text corpus $\mathcal{D}_\text{PT}$:
```bash
python3 tools/get_openwebtext.py
```
This script will replace the continuous `\n` in each document with a special token "<@x(x!>" and write each document in OpenWebText in a line, which is convenient for parallel processing. In `data/openwebtext/data.txt`, we give an example of the resulting format. You can follow this format to prepare other corpus beyond OpenWebText.

Tokenize the data and store them in binary files:
```bash
bash scripts/gpt2/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/gpt2/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data

bash scripts/opt/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/opt/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data

bash scripts/llama/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/llama/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data
```

## Base Pre-trained Models
To run fine-tuning or standard KD baselines, you need to download the model checkpoints from [Huggingface Model Hub] and put them in `checkpoints/`. For example, for gpt2-large, you can download the model from this [link](https://huggingface.co/gpt2-large/tree/main) and put them in `checkpoints/gpt2-large`.

Alternatively, you can also change the `CKPT` variable in each script to the corresponding model name to enable Transformers to download the base models automatically. For example, setting `CKPT="gpt2-large"` in `scripts/gpt2/sft/sft_large.sh` causes the gpt2-large base model to be downloaded from the HuggingFace model hub.

## Train
We provide example commands for GPT-2 models. Similar scripts for other model families can be found in `scripts/opt` and `scripts/openllama2`. All our experiments are conducted on 4 \* A100 40GB GPUs, which can be reduced for small models.

### Baselines
The final checkpoints are selected by the **ROUGE-L** scores.

#### Fine-tune the teacher models
```bash
bash scripts/gpt2/sft/sft_xlarge.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
#### SFT Baselines
```bash
bash scripts/gpt2/sft/sft_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### KD Baselines
```bash
bash scripts/gpt2/kd/kd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### SeqKD Baselines
Generate and process responses with the teacher:
```bash
bash scripts/gpt2/tools/generate_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/tools/process_pseudo_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```
Fine-tune the model with SeqKD:
```bash
bash scripts/gpt2/seqkd/seqkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### Student Initialization
The final checkpoints are selected by the **validation loss**.
```bash
bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### ImitKD Baselines
```bash
bash scripts/gpt2/imitkd/imitkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### MiniLLM Baselines
```bash
bash scripts/gpt2/minillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

#### GKD Baselines
```bash
bash scripts/gpt2/gkd/gkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

### DistiLLM
The final checkpoints are selected by the **validation loss**.
```bash
bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

The final checkpoints are selected by the **ROUGE-L** scores.
```bash
bash scripts/gpt2/distillm/train_0.1B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_0.3B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_0.7B_1.5B.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
```

## Run Evaluation
```bash
bash scripts/gpt2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
bash scripts/opt/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 
bash scripts/openllama2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 
```

## Results
DistiLLM outperforms other KD baselines in terms of both generation performance and training speed for various model families such as GPT-2, OPT, and OpenLLaMA.
<p align="center">
<img width="1394" src="https://github.com/jongwooko/distillm/assets/59277369/19ddac5c-4cd6-4d81-99d8-32723a8e60d8">
</p>

## Checkpoints (OpenLLaMA-3B)
We share the LoRA weights for OpenLLaMA-3B in [google drive](https://drive.google.com/drive/folders/1Yun1aNpn-mz2h-IVH_VdJ1Jhzm0K55Bo?usp=sharing).

## Acknowledgement
Our code is based on the code of ICLR2024 [MiniLLM: Knowledge Distillation of Large Language Models](https://arxiv.org/pdf/2306.08543.pdf).

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=jongwooko/distillm&type=Date)](https://star-history.com/#jongwooko/distillm&Date)

## BibTeX
If you find this repo useful for your research, please consider citing our paper:

```
@inproceedings{kodistillm,
  title={DistiLLM: Towards Streamlined Distillation for Large Language Models},
  author={Ko, Jongwoo and Kim, Sungnyun and Chen, Tianyi and Yun, Se-Young},
  booktitle={Forty-first International Conference on Machine Learning}
}

@article{ko2025distillm2,
      title={DistiLLM-2: A Contrastive Approach Boosts the Distillation of LLMs}, 
      author={Jongwoo Ko and Tianyi Chen and Sungnyun Kim and Tianyu Ding and Luming Liang and Ilya Zharkov and Se-Young Yun},
      year={2025},
      journal={arXiv preprint arXiv:2503.07067}
}
```

## Contact
- Jongwoo Ko: jongwoo.ko@kaist.ac.kr


================================================
FILE: arguments.py
================================================
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import deepspeed
import numpy as np


def add_model_args(parser: argparse.ArgumentParser):
    """Register model-related CLI options (paths, types, parallelism) and return the parser."""
    grp = parser.add_argument_group('model', 'model configuration')
    grp.add_argument('--model-path', type=str, help='model path')
    grp.add_argument('--ckpt-name', type=str)
    grp.add_argument('--model-type', type=str, default="gpt2")
    grp.add_argument('--teacher-model-type', type=str, default=None)
    grp.add_argument('--n-gpu', type=int, default=1)
    grp.add_argument('--n-nodes', type=int, default=1)
    # Teacher-side checkpoint options used by distillation runs.
    grp.add_argument('--teacher-model-path', type=str)
    grp.add_argument('--teacher-ckpt-name', type=str)
    grp.add_argument('--teacher-model-fp16', action="store_true")
    # Model-parallel configuration.
    grp.add_argument('--model-parallel', action="store_true")
    grp.add_argument('--model-parallel-size', type=int, default=None)
    grp.add_argument('--no-value', action="store_true")
    grp.add_argument('--dropout-path-rate', type=float, default=None)
    grp.add_argument('--fp32', action="store_true")
    return parser


def add_runtime_args(parser: argparse.ArgumentParser):
    """Register runtime/launcher CLI options (mode switches, I/O paths, logging cadence)."""
    grp = parser.add_argument_group('runtime', 'runtime configurations')

    grp.add_argument("--type", type=str, default=None)
    # Which phases to run.
    grp.add_argument("--do-train", action="store_true")
    grp.add_argument("--do-valid", action="store_true")
    grp.add_argument("--do-eval", action="store_true")
    grp.add_argument("--base-path", type=str, default=None, help='Path to the project base directory.')
    grp.add_argument("--load", type=str, default=None,
                     help='Path to a directory containing a model checkpoint.')
    grp.add_argument("--save", type=str, default=None,
                     help='Output directory to save checkpoints to.')
    # Logging / checkpointing intervals (in iterations).
    grp.add_argument("--log-interval", type=int, default=10)
    grp.add_argument("--mid-log-num", type=int, default=4)
    grp.add_argument("--save-interval", type=int, default=1000,
                     help='number of iterations between saves')
    grp.add_argument("--eval-interval", type=int, default=1000)
    # Underscore (not hyphen) spelling is required by the distributed launcher.
    grp.add_argument("--local_rank", type=int, default=None,
                     help='local rank passed from distributed launcher')
    grp.add_argument("--save-additional-suffix", type=str, default="")
    grp.add_argument("--save-rollout", action="store_true")
    grp.add_argument("--eb-sample-times", type=int, default=3)
    return parser


def add_data_args(parser: argparse.ArgumentParser):
    """Register dataset CLI options (locations, sizes, prompt lengths, formats)."""
    grp = parser.add_argument_group('data', 'data configurations')
    grp.add_argument("--data-dir", type=str, default=None)
    grp.add_argument("--processed-data-dir", type=str, default=None)
    grp.add_argument("--force-process", action="store_true")
    grp.add_argument("--force-process-demo", action="store_true")
    grp.add_argument("--data-process-workers", type=int, default=-1)
    # Train/dev sizing: -1 means "use all"; ratios subsample further.
    grp.add_argument("--train-num", type=int, default=-1)
    grp.add_argument("--train-ratio", type=float, default=1)
    grp.add_argument("--dev-num", type=int, default=-1)
    grp.add_argument("--dev-ratio", type=float, default=1)
    grp.add_argument("--gen-num", type=int, default=-1)
    grp.add_argument("--data-names", type=str, default=None)
    grp.add_argument("--prompt-type", type=str, default=None)
    grp.add_argument("--num-workers", type=int, default=1)
    grp.add_argument("--max-prompt-length", type=int, default=512)
    grp.add_argument("--min-prompt-length", type=int, default=128)
    # On-disk format flags.
    grp.add_argument("--json-data", action="store_true")
    grp.add_argument("--bin-data", action="store_true")
    grp.add_argument("--txt-data", action="store_true")

    grp.add_argument("--prompt-data-dir", type=str)
    grp.add_argument("--lm-data-dir", type=str)
    # Which evaluations to run.
    grp.add_argument("--eval-ppl", action="store_true")
    grp.add_argument("--eval-rw", action="store_true")
    grp.add_argument("--eval-gen", action="store_true")

    grp.add_argument("--only-prompt", action="store_true")
    return parser


def add_hp_args(parser: argparse.ArgumentParser):
    """Register optimization hyper-parameter CLI options.

    Adds batch sizes, iteration/epoch budgets, random seeds, and the
    learning-rate schedule options to *parser* and returns it.
    """
    group = parser.add_argument_group("hp", "hyper parameter configurations")
    group.add_argument('--batch-size', type=int, default=32,
                       help='Data Loader batch size')
    group.add_argument('--eval-batch-size', type=int, default=32,
                       help='Data Loader batch size')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='gradient clipping')
    group.add_argument('--total-iters', type=int, default=None,
                       help='total number of iterations')
    group.add_argument('--train-iters-per-epoch', type=int, default=-1,
                       help='total number of iterations per epoch')
    group.add_argument('--max-length', type=int, default=1024,
                       help='max length of input')
    group.add_argument('--seed', type=int, default=1234,
                       help='random seed for reproducibility')
    group.add_argument("--seed-order", type=int, default=42)
    group.add_argument("--seed-data", type=int, default=42)
    group.add_argument("--seed-ppo", type=int, default=42)
    group.add_argument("--seed-lm", type=int, default=7)
    group.add_argument('--epochs', type=int, default=None,
                       help='total number of epochs to train over all training runs')
    group.add_argument('--training-epochs', type=int, default=10000)
    group.add_argument("--gradient-accumulation-steps", type=int, default=1)
    group.add_argument("--gradient-checkpointing", action="store_true")
    group.add_argument("--attn-dtype", default=None)

    group.add_argument('--lr', type=float, help='initial learning rate')
    group.add_argument("--lr-min", type=float, default=0.0000001)
    group.add_argument('--weight-decay', type=float, default=1.0e-2,
                       help='weight-decay')
    group.add_argument('--loss-scale', type=float, default=65536,
                       help='loss scale')
    group.add_argument("--kd-ratio", type=float, default=None)

    # FIX: the previous help text described this int option as a percentage
    # ("(.01 = 1% of all training iters). Default 0.01") — wrong for an int
    # with default 0, and the bare '%' made argparse's help formatter raise
    # "ValueError: unsupported format character" whenever --help was printed.
    group.add_argument('--warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warm up the '
                       'learning rate over. Default 0')
    # FIX: the fallback referenced a nonexistent `--train-iters` option; the
    # actual option registered above is `--total-iters`.
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay LR over,'
                       ' If None defaults to `--total-iters`*`--epochs`')
    group.add_argument('--lr-decay-style', type=str, default='noam',
                       choices=['constant', 'linear', 'cosine', 'exponential', 'noam'],
                       help='learning rate decay function')
    group.add_argument("--scheduler-name", type=str, default="constant_trm")

    return parser


def add_ppo_args(parser: argparse.ArgumentParser):
    """Register PPO rollout/optimization CLI options and return the parser."""
    grp = parser.add_argument_group('ppo', 'ppo configurations')

    grp.add_argument("--reward-scaling", type=float, default=None)
    grp.add_argument("--cliprange-reward", type=float, default=1)
    grp.add_argument("--ppo-epochs", type=int, default=None)
    # Rollout collection sizing.
    grp.add_argument("--num-rollouts", type=int, default=256)
    grp.add_argument("--num-rollouts-per-device", type=int, default=None)
    grp.add_argument("--cliprange", type=float, default=0.2)
    grp.add_argument("--chunk-size", type=int, default=None)
    grp.add_argument("--gamma", type=float, default=0.95)

    return parser


def add_minillm_args(parser: argparse.ArgumentParser):
    """Register MiniLLM-baseline CLI options and return the parser."""
    grp = parser.add_argument_group('minillm', 'minillm configurations')

    grp.add_argument("--length-norm", action="store_true")
    grp.add_argument("--single-step-reg", action="store_true")
    # Mixing coefficient for teacher-mixed sampling; None disables it.
    grp.add_argument("--teacher-mixed-alpha", type=float, default=None)
    grp.add_argument("--lm-coef", type=float, default=1)

    return parser


def add_distillm_args(parser: argparse.ArgumentParser):
    """Register DistiLLM-specific CLI options: skew KLD, student generation,
    the adaptive threshold, and the off-policy replay buffer."""
    grp = parser.add_argument_group('distillm', 'distillm configurations')

    # Skew KLD loss.
    grp.add_argument("--skew-alpha", type=float, default=0.1)

    # Student-generated output (SGO) sampling.
    grp.add_argument("--student-gen", action="store_true")
    grp.add_argument("--gen-top-p", type=float, default=1.0)
    grp.add_argument("--gen-num-beams", type=int, default=2)

    # Adaptive off-policy threshold.
    grp.add_argument("--mixed-alpha", type=float, default=0.5)
    grp.add_argument("--loss-eps", type=float, default=0.1)
    grp.add_argument("--init-threshold", type=float, default=0.0)

    # Off-policy replay buffer.
    grp.add_argument("--capacity", type=int, default=1000)
    grp.add_argument("--replay-ratio", type=str, default="decreasing")
    return parser


def add_gen_args(parser: argparse.ArgumentParser):
    """Register decoding CLI options (sampling, beams, repetition control)."""
    grp = parser.add_argument_group('generation', 'generation configurations')

    # Sampling controls: top-k of 0 disables top-k filtering.
    grp.add_argument("--top-k", type=int, default=0)
    grp.add_argument("--top-p", type=float, default=1.0)
    grp.add_argument("--do-sample", action="store_true")
    grp.add_argument("--no-repeat-ngram-size", type=int, default=6)
    grp.add_argument("--repetition-penalty", type=float, default=None)
    grp.add_argument("--num-beams", type=int, default=1)
    grp.add_argument("--temperature", type=float, default=1)

    return parser


def add_peft_args(parser: argparse.ArgumentParser):
    """Register PEFT (e.g. LoRA) CLI options for student and teacher models.

    FIX: the argument group was labeled ('generation', 'generation
    configurations') — copy-pasted from add_gen_args — which produced a
    second, mislabeled "generation" section in --help output.
    """
    group = parser.add_argument_group('peft', 'peft configurations')

    # PEFT method selector (e.g. "lora"); None disables PEFT.
    group.add_argument("--peft", type=str, default=None)
    group.add_argument("--peft-lora-r", type=int, default=16)
    group.add_argument("--peft-lora-alpha", type=int, default=64)
    group.add_argument("--peft-lora-dropout", type=float, default=0.1)
    group.add_argument("--peft-name", type=str, default=None)
    group.add_argument("--peft-path", type=str, default=None)
    # Teacher-side adapter, for distillation from a PEFT-tuned teacher.
    group.add_argument("--teacher-peft-name", type=str, default=None)
    group.add_argument("--teacher-peft-path", type=str, default=None)
    return parser


def get_args():
    """Parse CLI arguments and derive the run's save directory from args.type.

    Fixes relative to the original:
    * "lm"/"kd"/"minillm" branches: the optional "-{peft_name}" /
      "-{teacher_peft_name}" suffixes were written as unparenthesized
      conditional expressions; Python's ternary precedence made the whole
      name expression collapse to "" (dropping ckpt_name, and in "kd" the
      teacher name) whenever the relevant peft name was None. Each optional
      suffix is now parenthesized/computed on its own.
    * ppo_prefix: "_rs{reward_scaling}" was gated on
      `args.ppo_epochs is not None` (copy-paste error); it is now gated on
      `args.reward_scaling is not None`.
    """
    parser = argparse.ArgumentParser()
    parser = add_model_args(parser)
    parser = add_runtime_args(parser)
    parser = add_data_args(parser)
    parser = add_hp_args(parser)
    parser = add_ppo_args(parser)
    parser = add_minillm_args(parser)
    parser = add_distillm_args(parser)
    parser = add_gen_args(parser)
    parser = add_peft_args(parser)
    parser = deepspeed.add_config_arguments(parser)

    args, unknown = parser.parse_known_args()

    # Tolerate unknown positionals, but an unknown --flag is almost
    # certainly a typo: fail fast.
    assert all("--" not in x for x in unknown), unknown

    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    # Total GPU count across all nodes.
    args.n_gpu = args.n_gpu * args.n_nodes

    # Optional "-{name}" pieces; empty string when the name is unset.
    peft_suffix = f"-{args.peft_name}" if args.peft_name is not None else ""
    teacher_peft_suffix = f"-{args.teacher_peft_name}" if args.teacher_peft_name is not None else ""

    if args.type == "eval_main":
        ckpt_name = None
        if args.ckpt_name is not None:
            ckpt_name = args.ckpt_name
        if args.peft_name is not None:
            ckpt_name = args.peft_name

        if ckpt_name is not None:
            # Flatten "a/b/step" into "a_b/step", keeping a numeric leaf dir.
            tmp = ckpt_name.split("/")
            if tmp[-1].isdigit():
                ckpt_name = "_".join(tmp[:-1]) + "/" + tmp[-1]
            else:
                ckpt_name = "_".join(tmp)

        save_path = os.path.join(
            args.save,
            f"{args.data_names}-{args.max_length}" + (f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else ""),
            ckpt_name,
            f"{args.seed}",
        )
        args.save = save_path
    elif args.type == "lm":
        save_path = os.path.join(
            args.save,
            f"{args.ckpt_name}" + peft_suffix,
            (f"e{args.epochs}-bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}") + \
            (f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "") + \
            (f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else "") + \
            args.save_additional_suffix
        )
        args.save = save_path
    elif args.type == "kd":
        save_path = os.path.join(
            args.save,
            f"{args.ckpt_name}" + peft_suffix + f"-{args.teacher_ckpt_name}" + teacher_peft_suffix,
            (f"e{args.epochs}-bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}-kd{args.kd_ratio}") + \
            (f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "") + \
            (f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else "") + \
            args.save_additional_suffix
        )
        args.save = save_path
    elif args.type == "gen":
        save_path = os.path.join(
            args.save,
            f"{args.ckpt_name}",
            f"t{args.temperature}-l{args.max_length}",
        )
        args.save = save_path
    elif args.type == "minillm":
        ppo_prefix = f"pe{args.ppo_epochs}" + \
                     (f"_rs{args.reward_scaling}" if args.reward_scaling is not None else "") + \
                     (f"_nr{args.num_rollouts}" if args.num_rollouts is not None else "") + \
                     (f"_ln" if args.length_norm else "") + \
                     (f"_sr" if args.single_step_reg else "") + \
                     (f"_tm{args.teacher_mixed_alpha}" if args.teacher_mixed_alpha is not None else "")
        save_path = os.path.join(
            args.save,
            f"{args.ckpt_name}" + peft_suffix + f"-{args.teacher_ckpt_name}" + teacher_peft_suffix,
            (f"bs{args.batch_size}-lr{args.lr}-G{args.gradient_accumulation_steps}-N{args.n_gpu}-NN{args.n_nodes}-lm{args.lm_coef}-len{args.max_length}" + \
                (f"-mp{args.model_parallel_size}" if args.model_parallel > 0 else "")) + \
            (f"-lora-{args.peft_lora_r}-{args.peft_lora_alpha}-{args.peft_lora_dropout}" if args.peft == "lora" else ""),
            ppo_prefix + args.save_additional_suffix
        )
        args.save = save_path
        # Rollouts are gathered per device; assumes num_rollouts divides
        # evenly across GPUs — TODO confirm with callers.
        args.num_rollouts_per_device = args.num_rollouts // args.n_gpu

        if args.warmup_iters > 0:
            assert args.scheduler_name is not None

    return args


================================================
FILE: configs/deepspeed/ds_config.json
================================================
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 1
    },
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 11,
        "loss_scale_window": 2000,
        "hysteresis": 4
    },
    "wall_clock_breakdown": false
}

================================================
FILE: configs/deepspeed/ds_config_fp32.json
================================================
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 1
    },
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": false
    },
    "wall_clock_breakdown": false
}

================================================
FILE: configs/deepspeed/ds_config_zero2.json
================================================
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none"
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 11,
        "loss_scale_window": 5000,
        "hysteresis": 4
    },
    "wall_clock_breakdown": false
}

================================================
FILE: configs/deepspeed/ds_config_zero2_offload.json
================================================
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "zero_force_ds_cpu_optimizer": false,
    "zero_allow_untested_optimizer": true,
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "initial_scale_power": 11,
        "loss_scale_window": 5000,
        "hysteresis": 4
    },
    "wall_clock_breakdown": false
}

================================================
FILE: configs/hostfiles/node_0_1
================================================
node-0 slots=8
node-1 slots=8

================================================
FILE: configs/hostfiles/node_0_1_2_3
================================================
node-0 slots=8
node-1 slots=8
node-2 slots=8
node-3 slots=8

================================================
FILE: configs/hostfiles/node_1_2
================================================
node-1 slots=8
node-2 slots=8

================================================
FILE: configs/hostfiles/node_2_3
================================================
node-2 slots=8
node-3 slots=8

================================================
FILE: data_utils/distributed_indexed.py
================================================
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import struct
import shutil

from itertools import accumulate

import numpy as np
import torch
import torch.distributed as dist
from utils import print_rank, save_rank


# Mapping from the 1-byte dtype code stored in an index-file header to the
# numpy dtype used to decode the binary payload.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.double,
    8: np.uint16,
    9: np.uint32
}


def code(dtype):
    """Reverse lookup: return the header code for numpy `dtype`.

    Raises ValueError when the dtype is not in the table.
    """
    found = next((c for c, d in dtypes.items() if d == dtype), None)
    if found is None:
        raise ValueError(dtype)
    return found


def index_file_path(prefix_path):
    """Return the .idx (metadata) filename for dataset basename `prefix_path`."""
    suffix = '.idx'
    return prefix_path + suffix


def data_file_path(prefix_path):
    """Return the .bin (payload) filename for dataset basename `prefix_path`."""
    suffix = '.bin'
    return prefix_path + suffix


class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
    """Reader for a memory-mapped indexed dataset split into numbered shards.

    Data lives in file pairs ``{path}{name}_{state}.bin`` /
    ``{path}{name}_{state}.idx`` with ``state`` counting shards from 0.
    Items are addressed by a single global index; ``__getitem__`` hops
    between shard files (via ``_next_file``) until the mapped shard's
    index range contains the requested item.
    """
    class Index(object):
        """Parser for one shard's .idx file (sizes, pointers, doc index)."""
        _HDR_MAGIC = b'MMIDIDX\x00\x00'
        def __init__(self, path):
            with open(path, 'rb') as stream:
                # Validate the 9-byte magic and the little-endian version word.
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                # 1-byte code selecting the numpy dtype of the .bin payload.
                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                # Metadata arrays begin right after the fixed header.
                offset = stream.tell()

            # Map the index file and view the three metadata arrays in place
            # (no copies): int32 sizes, int64 byte pointers, int64 doc index.
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            # Explicitly close the underlying mmap before dropping the array.
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        def __getitem__(self, i):
            # (byte pointer into the .bin file, item length in elements)
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, name, rank_number, rank_total, cache = None):
        """
        Args:
            path: directory prefix containing the shard files.
            name: dataset name prefix, e.g. the split ("train"/"valid").
            rank_number / rank_total: distributed rank info (stored only;
                per-rank sharding is not performed inside this class).
            cache: optional directory, created if given.  NOTE(review): it
                is stored but never read elsewhere in this class — confirm
                whether it is still needed.
        """
        super().__init__()

        self._path = path
        self._name = name
        self._state = 0  # index of the currently mapped shard file
        if cache is not None:
            self._cache = cache
            os.makedirs(self._cache, exist_ok=True)
        else:
            self._cache = None
        self._rank_total = rank_total
        self._rank_number = rank_number
        self._index = None
        self._bin_buffer = None
        self._bin_buffer_mmap = None
        # history maps shard state -> (global start index, global end index).
        self.max_state, self.history = self._probe_data_path(self._path, self._name, self._rank_total)
        self.total_length = self.history[self.max_state-1][1]

        self._do_init(self._path, self._name, self._cache, self._state)

    def _probe_data_path(self, path, name, rank_total):
        """Count consecutive shard files and build their global index ranges."""
        print_rank("Probing Dataset")
            
        state = 0
        # Sentinel at -1 so shard 0's range starts at global index 0.
        history = {-1:(0, 0)}
        for state in range(np.iinfo(np.int32).max):
            source_file = path + name + f"_{state}"
            if self.exists(source_file):
                index = self.Index(index_file_path(source_file))
                history[state] = (history[state-1][1], history[state-1][1] + len(index))
            else:
                # First missing shard number terminates the probe; `state`
                # is then the number of shards found.
                break
            
        print_rank(f"Probing end. Max data state {state}, total length {history[state-1][1]}")
        
        return state, history

    def __getstate__(self):
        # Pickled as the concrete shard filename currently in use.
        return self._path + self._name + "_%d"%(self._state)

    def __setstate__(self, state):
        # NOTE(review): `state` here is the string from __getstate__, but it
        # is used below as the integer shard number (and self._path is not
        # restored) — confirm these objects are never actually unpickled.
        self._state = state
        self._do_init(self._path, self._name, self._cache, self._state)

    def _do_init(self, path, name, cache, state):
        # Release the previously mapped shard before mapping the next one.
        if self._bin_buffer_mmap is not None:
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
        if self._index is not None:
            del self._index

        self._state = state

        # Map shard `state`'s index and payload files.
        source_file = path + name + f"_{self._state}"
        self._index = self.Index(index_file_path(source_file))
        self._bin_buffer_mmap = np.memmap(data_file_path(source_file), mode='r', order='C')
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        if self._bin_buffer_mmap is not None:
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
        if self._index is not None:
            del self._index

    def __len__(self):
        # Total number of items across all shards, not just the mapped one.
        return self.total_length

    def _next_file(self):
        # Advance to the next shard, wrapping back to shard 0 past the end.
        self._state += 1
        if self._state >= self.max_state:
            self._state = 0
        # print_rank(f"next_file: {self._state}")
        self._do_init(self._path, self._name, self._cache, self._state)
    
    def __relative_idx(self, idx):
        # Convert a global item index to an index within the current shard.
        res = idx - self.history[self._state][0]
        return res

    def __slice_item(self, start, stop):
        # Read [start, stop) as one contiguous buffer, then split per item.
        ptr = self._index._pointers[self.__relative_idx(start)]
        sizes = self._index._sizes[self.__relative_idx(start):self.__relative_idx(stop)]
        offsets = list(accumulate(sizes))
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=sum(sizes), offset=ptr)
        return np.split(np_array, offsets[:-1])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            # Hop shards until the mapped one covers `idx` (wraps around).
            while idx >= self.history[self._state][1] or idx < self.history[self._state][0]:
                self._next_file()
            ptr, size = self._index[self.__relative_idx(idx)]
            return np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
        elif isinstance(idx, slice):
            raise NotImplementedError()

    @property
    def sizes(self):
        return self._index.sizes
        
    def exists(self, path):
        # A shard exists only when both its .idx and .bin files are present.
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


================================================
FILE: data_utils/indexed_dataset.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
#    An empty sentence no longer separates documents.

from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate

import numpy as np
import torch


def __best_fitting_dtype(vocab_size=None):
    """Pick the smallest integer dtype able to hold ids below `vocab_size`."""
    if vocab_size is None:
        return np.int32
    return np.uint16 if vocab_size < 65500 else np.int32


def get_available_dataset_impl():
    """Names of the supported on-disk dataset implementations."""
    return list(('lazy', 'cached', 'mmap'))


def infer_dataset_impl(path):
    """Guess the implementation ('cached' or 'mmap') from the index header.

    Returns None (after printing a hint) when the dataset files are missing
    or the magic bytes match neither format.
    """
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    with open(index_file_path(path), 'rb') as f:
        magic = f.read(8)
    if magic == IndexedDataset._HDR_MAGIC:
        return 'cached'
    if magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
        return 'mmap'
    return None


def make_builder(out_file, impl, dtype):
    """Create the dataset builder matching `impl` ('mmap' or legacy format)."""
    if impl != 'mmap':
        return IndexedDatasetBuilder(out_file)
    return MMapIndexedDatasetBuilder(out_file, dtype=dtype)


def make_dataset(path, impl, skip_warmup=False):
    """Open the indexed dataset at basename `path` using implementation `impl`.

    `impl` may be 'lazy', 'cached', 'mmap', or 'infer' (sniff the header).
    Returns None (after printing a hint) when files are missing or the
    implementation name is unknown.
    """
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    if impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    if impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    """True when both .idx and .bin exist for `path` under `impl`."""
    checker = MMapIndexedDataset.exists if impl == 'mmap' else IndexedDataset.exists
    return checker(path)


def read_longs(f, n):
    """Read `n` native-endian int64 values from binary stream `f`."""
    out = np.empty(n, dtype=np.int64)
    f.readinto(out)
    return out


def write_longs(f, a):
    """Serialize sequence `a` to stream `f` as raw int64 values."""
    f.write(np.array(a, dtype=np.int64).tobytes())


# Integer codes used in index-file headers, mapped to their numpy dtypes.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.double,
    8: np.uint16,
    9: np.uint32
}


def code(dtype):
    """Return the header code for `dtype`; raise ValueError if unsupported."""
    for dtype_code, np_dtype in dtypes.items():
        if np_dtype == dtype:
            return dtype_code
    raise ValueError(dtype)


def index_file_path(prefix_path):
    """Filename of the .idx sidecar for dataset basename `prefix_path`."""
    return prefix_path + '.idx'[:]


def data_file_path(prefix_path):
    """Filename of the .bin payload for dataset basename `prefix_path`."""
    extension = '.bin'
    return prefix_path + extension


def create_doc_idx(sizes):
    """Document boundaries: a new document starts after each zero-size entry."""
    return [0] + [pos + 1 for pos, size in enumerate(sizes) if size == 0]


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset: the legacy (non-mmap) .idx/.bin format.

    The .idx file is fully parsed up front; the .bin payload is opened
    lazily and read with explicit seeks per item.
    """
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None  # payload handle, opened lazily on first access
        self.read_index(path)

    def read_index(self, path):
        # Parse header (magic, version, dtype code, element size, counts)
        # followed by the four int64 metadata arrays.
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                'Index file doesn\'t match expected format. '
                'Make sure that --dataset-impl is configured properly.'
            )
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        # Unbuffered handle: every read below is a whole item anyway.
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        # Raise IndexError for out-of-range item indices.
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        """Return item `idx` as a numpy array; for a contiguous slice, a list
        of arrays (one per item)."""
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            # Shape comes from the sizes array; offsets are in elements,
            # converted to bytes via element_size for the seek.
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            # One contiguous read for the whole range, then split per item.
            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        # The dataset requires both the .idx and the .bin file.
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    """IndexedDataset variant that preloads requested items into one RAM buffer.

    Call `prefetch(indices)` first; `__getitem__` then serves items from the
    in-memory cache (raising KeyError for indices never prefetched).
    """

    def __init__(self, path):
        super().__init__(path)
        self.cache = None        # flat numpy buffer holding all prefetched items
        self.cache_index = {}    # item index -> start offset within `cache`

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Read `indices` (deduplicated, sorted) from disk into the cache."""
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            # Read each item directly into its slice of the cache buffer.
            a = self.cache[ptx: ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            # KeyError here means the index was never prefetched.
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx: ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work, can optimizer later if necessary
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    """Writer for the legacy (non-mmap) indexed dataset format.

    Streams tensors to the .bin file while accumulating the offset/size
    metadata, which `finalize` writes to the .idx sidecar.

    Fixes relative to the original:
    * Element offsets were accumulated with true division (`/`), producing
      float offsets that propagate to the reader's `file.seek`.  Writes are
      always whole elements, so floor division (`//`) is exact and keeps
      the offsets integers.
    * The local `bytes` shadowed the builtin; renamed to `num_bytes`.
    """
    # Bytes per element for each supported dtype.
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.double: 8
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]   # element offset where each item starts
        self.dim_offsets = [0]    # cumulative count of size entries per item
        self.sizes = []           # flattened tensor shapes
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]        # item index where each document starts

    def add_item(self, tensor):
        """Append one torch tensor to the data file and record its metadata."""
        num_bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        # Offsets are stored in elements, not bytes; the write is always a
        # whole number of elements, so integer division is exact.
        self.data_offsets.append(self.data_offsets[-1] + num_bytes // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        # Mark the current item count as a document boundary.
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        """Append another legacy dataset's metadata and raw payload to this one."""
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        # Shift the incoming offsets so they continue after our data.
        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        # Byte-copy the other dataset's payload.
        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        """Close the data file and write the .idx header plus metadata arrays."""
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    """Memory-mapped indexed dataset reader (.idx metadata + .bin payload).

    Fix relative to the original: ``__setstate__`` called
    ``self._do_init(state)`` even though ``_do_init`` requires a
    ``skip_warmup`` argument, so unpickling always raised ``TypeError``.
    The warmup pass is now explicitly skipped on restore.
    """
    class Index(object):
        """Reader and writer for the .idx sidecar file."""
        _HDR_MAGIC = b'MMIDIDX\x00\x00'

        @classmethod
        def writer(cls, path, dtype):
            """Return a context manager that writes an index for `dtype` items."""
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, 'wb')

                    # Header: magic, version word, 1-byte dtype code.
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    # Byte offset of each item, from cumulative element counts.
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    """Write counts plus the sizes/pointers/doc_idx arrays."""
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                # Validate the 9-byte magic and the version word.
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                # 1-byte code selecting the numpy dtype of the .bin payload.
                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                # Metadata arrays start right after the fixed header.
                offset = stream.tell()

            if not skip_warmup:
                print("    warming up index mmap file...")
                _warmup_mmap_file(path)

            # Map the index file and view the metadata arrays in place.
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print("    reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            print("    reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print("    reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        # NOTE(review): lru_cache on a method keys on `self` and keeps the
        # Index alive for the cache's lifetime; kept as-is to preserve the
        # original caching behavior.
        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            # (byte pointer into the .bin file, item length in elements)
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        # Pickle only the path; the mmap is re-created on restore.
        return self._path

    def __setstate__(self, state):
        # Fix: the original called `self._do_init(state)` without the
        # required `skip_warmup` argument, raising TypeError on unpickle.
        # Warmup is unnecessary on restore, so skip it.
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        print("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        """Return item `idx` as a zero-copy numpy view; for a contiguous
        slice, a list of views (one per item)."""
        if isinstance(idx, int):
            assert idx < len(self._index), "Index {} out of range: {}".format(idx, len(self._index))
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            # One contiguous view over the range, split at item boundaries.
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """ Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        # Advance the byte pointer by `offset` elements.
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                 count=length, offset=ptr)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    # @property
    # def doc_idx(self):
    #     return self._index.doc_idx

    # def get_doc_idx(self):
    #     return self._index._doc_idx

    # def set_doc_idx(self, doc_idx_):
    #     self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        # The dataset requires both the .idx and the .bin file.
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


class MMapIndexedDatasetBuilder(object):
    """Incrementally writes a ``.bin`` data file for an mmap-indexed dataset.

    Item values are appended as raw ``dtype`` bytes; per-item sizes and
    document boundaries are buffered in memory until ``finalize`` writes
    the companion index file.
    """

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._dtype = dtype
        self._sizes = []      # element count of every added item
        self._doc_idx = [0]   # item index at which each document starts

    def add_item(self, tensor):
        """Append one tensor's values (cast to the builder dtype)."""
        arr = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(arr.tobytes(order='C'))
        self._sizes.append(arr.size)

    def end_document(self):
        """Record a document boundary after the most recently added item."""
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        """Append another dataset's data and sizes; dtypes must match."""
        other_index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert other_index.dtype == self._dtype

        self._sizes.extend(other_index.sizes)

        # Raw byte-level concatenation of the other .bin file.
        with open(data_file_path(another_file), 'rb') as src:
            shutil.copyfileobj(src, self._data_file)

    def finalize(self, index_file):
        """Close the data file and write the index for all buffered items."""
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as idx_writer:
            idx_writer.write(self._sizes, self._doc_idx)


================================================
FILE: data_utils/lm_datasets.py
================================================
import random
import torch
import os
import json
import pickle
import numpy as np
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset

from torch.distributed import get_rank, get_world_size, barrier
from utils import print_rank
from utils import save_rank


class LMTrainDataset(Dataset):
    """Language-modeling training dataset backed by a distributed mmap shard.

    Each item is a packed token-id sequence; a sentinel id (65535) inside the
    sequence marks the boundary between prompt and response (presumably
    inserted by the data-processing pipeline — confirm there).
    """

    def __init__(self, args, tokenizer, path, split, num, ratio, rng_sample: random.Random):
        """Load the `{split}` shard for this rank, plus optional raw answers.

        Args:
            args: run configuration (max_length, max_prompt_length, model_type, ...).
            tokenizer: tokenizer; its eos token id doubles as the pad id.
            path: directory with the binary shards and optional `{split}.jsonl`.
            split: split name, e.g. "train".
            num: number of instances to expose; -1 means all.
            ratio: stored but not used in the code visible here.
            rng_sample: stored but not used in the code visible here.
        """
        self.args = args
        self.tokenizer = tokenizer
        self.split = split
        # Pad with eos so padded positions decode harmlessly.
        self.pad_id = self.tokenizer.eos_token_id
        self.ratio = ratio
        self.max_length = args.max_length
        self.max_prompt_length = args.max_prompt_length
        self.rng_sample = rng_sample
        self.lm_ctx = DistributedMMapIndexedDataset(path, f"{split}", get_rank(), get_world_size())

        # Optional references; "output" may be a single string or a list.
        if os.path.exists(os.path.join(path, f"{split}.jsonl")):
            with open(os.path.join(path, f"{split}.jsonl")) as f:
                self.raw = [json.loads(line) for line in f.readlines()]
                self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
        
        print_rank(len(self.lm_ctx))
        if num == -1:
            self.num = len(self.lm_ctx)
        else:
            self.num = num

        print_rank(f"Num LM instances: {len(self.lm_ctx)}")

    def __len__(self):
        return self.num
   
    def __getitem__(self, index):
        return self._get_lm(index)
    
    def _get_lm(self, index):
        # Raw token ids arrive as a numpy array from the mmap dataset.
        data = self.lm_ctx[index]
        input_ids = data.astype(int)
        return {
            "input_ids": input_ids
        }

    def _process_lm(self, i, samp, model_data, no_model_data, gen_data):
        """Fill row `i` of the pre-allocated batch tensors from one sample.

        Builds shifted inputs/labels for teacher-forced LM training and, when
        a prompt delimiter is present, a left-padded prompt for generation.
        """
        input_ids = samp["input_ids"]
        source_len = 1
        
        prompt = None
        # 65535 is the prompt/response delimiter inside the packed ids.
        if 65535 in input_ids:
            source_len = np.where(input_ids==65535)[0][0]
            prompt = input_ids[:source_len]
            # Remove the delimiter token itself from the training sequence.
            input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0)
        input_ids = input_ids[:self.max_length]
        input_len = len(input_ids)
        # Teacher-forcing shift: inputs are ids[:-1], labels are ids[1:].
        model_data["input_ids"][i][:input_len-1] = torch.tensor(input_ids[:-1], dtype=torch.long)
        model_data["attention_mask"][i][:input_len-1] = 1.0
        if self.args.model_type in ["gpt2"]:
            model_data["position_ids"][i][:input_len-1] = torch.arange(0, input_len-1, dtype=torch.long)
        no_model_data["label"][i][:input_len-1] = torch.tensor(input_ids[1:], dtype=torch.long)
        # Mask the prompt part: loss is computed on the response tokens only.
        no_model_data["label"][i][:source_len-1] = -100
        no_model_data["loss_mask"][i][:input_len-1] = 1.0
        no_model_data["loss_mask"][i][:source_len-1] = 0
        
        if prompt is not None:
            # Left padding: generation prompts are right-aligned.
            gen_data["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
            gen_data["attention_mask"][i][-len(prompt):] = 1.0

    def move_to_device(self, model_data, no_model_data, gen_data, device):
        # Move all three batch dicts to `device` in place and return them.
        for k in model_data:
            model_data[k] = model_data[k].to(device)

        for k in no_model_data:
            no_model_data[k] = no_model_data[k].to(device)

        for k in gen_data:
            gen_data[k] = gen_data[k].to(device)

        return model_data, no_model_data, gen_data

    def collate(self, samples):
        """Collate samples into padded (model, no_model, gen) batch dicts."""
        bs = len(samples)

        max_length = self.max_length
        
        model_data = {
            # Pre-filled with pad (=eos); rows are overwritten per sample.
            "input_ids": torch.ones(bs, max_length, dtype=torch.long) * self.pad_id,
            "attention_mask": torch.zeros(bs, max_length),
        }
        
        if self.args.model_type in ["gpt2"]:
            model_data["position_ids"] = torch.zeros(bs, max_length, dtype=torch.long)
            
        no_model_data = {
            # -100 marks positions ignored by the loss.
            "label": torch.ones(bs, max_length, dtype=torch.long) * -100,
            "loss_mask": torch.zeros(bs, max_length)
        }
        
        gen_data = {
            "input_ids": torch.ones(bs, self.max_prompt_length, dtype=torch.long) * self.pad_id,
            "attention_mask": torch.zeros(bs, self.max_prompt_length, dtype=torch.long),
        }

        for i, samp in enumerate(samples):
            self._process_lm(i, samp, model_data, no_model_data, gen_data)
        
        return model_data, no_model_data, gen_data


================================================
FILE: data_utils/prompt_datasets.py
================================================
import random
import torch
import os
from torch.utils.data import Dataset
from .distributed_indexed import DistributedMMapIndexedDataset

from torch.distributed import get_rank, get_world_size
from utils import print_rank
from tqdm import tqdm
import json


class PromptDataset(Dataset):
    """Prompt/response dataset for generation-time evaluation.

    Prompts come from one of three sources, selected by `args`:
      * binary mmap shards (`args.bin_data`),
      * a jsonl file with "prompt" / optional "output" fields (`args.json_data`),
      * a plain-text file with one prompt per line (fallback).

    When a `{split}.jsonl` (or the model-type-specific variant) exists next to
    the data, its "output" fields are kept as reference answers.
    """

    def __init__(self, args, tokenizer, split, data_path=None, num=-1):
        """Load prompts (and answers, when available) for `split`.

        Args:
            args: run configuration (max_length, max_prompt_length, ...).
            tokenizer: tokenizer; its eos token id doubles as the pad id.
            split: split name, e.g. "valid".
            data_path: directory holding the data files.
            num: cap on the number of instances; <= 0 means "all".
        """
        super().__init__()
        self.tokenizer = tokenizer

        self.args = args
        self.tokenizer = tokenizer
        self.split = split
        self.pad_id = self.tokenizer.eos_token_id
        self.max_length = args.max_length
        self.min_prompt_length = args.min_prompt_length
        self.max_prompt_length = args.max_prompt_length

        if args.bin_data:
            self.data = DistributedMMapIndexedDataset(data_path, f"{split}", get_rank(), get_world_size())
        elif args.json_data:
            self.data, self.origin_data = self.load_data_json(data_path)
        else:
            # txt data
            self.data = self.load_data_txt(data_path)

        # Reference answers; the model-type-specific jsonl takes precedence.
        # FIX: default to empty lists so the label_map construction below does
        # not raise AttributeError when no answer file exists.
        self.raw, self.answers = [], []
        if os.path.exists(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")):
            with open(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")) as f:
                self.raw = [json.loads(line) for line in f.readlines()]
                self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
        elif os.path.exists(os.path.join(data_path, f"{split}.jsonl")):
            with open(os.path.join(data_path, f"{split}.jsonl")) as f:
                self.raw = [json.loads(line) for line in f.readlines()]
                self.answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.raw]
        else:
            print_rank("WARNING: No answers exist")

        # Map: first token id of each first answer -> that answer string
        # (used by `verbalizer` for classification-style eval).
        self.label_map = {tokenizer.encode(x[0], add_special_tokens=False)[0]: x[0] for x in self.answers}

        self.num = min(num, len(self.data)) if num > 0 else len(self.data)
        print_rank(f"Num instances: {len(self.data)}")

    def __len__(self):
        return self.num

    def load_data_json(self, data_path):
        """Read and tokenize prompts (and optional outputs) from a jsonl file.

        Returns:
            (data, data_origin): tokenized entries and the raw parsed lines.
        """
        if os.path.exists(os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")):
            data_path = os.path.join(data_path, f"{self.split}_{self.args.model_type}.jsonl")
        else:
            data_path = os.path.join(data_path, f"{self.split}.jsonl")
        
        with open(data_path) as f:
            lines = f.readlines()
        data_origin = [json.loads(line) for line in lines]
        data = []
        print_rank("Loading Data")
        for d in tqdm(data_origin, disable=(get_rank() != 0)):
            # "<n>" encodes newlines in the stored prompts.
            prompt = d["prompt"].replace("<n>", "\n")
            prompt_ids = self.tokenizer.encode(prompt)
            output_ids = None
            if "output" in d:
                if isinstance(d["output"], list):
                    output_ids = self.tokenizer.encode(d["output"][0])
                else:
                    output_ids = self.tokenizer.encode(d["output"])
            # FIX: only slice when an output is present; entries without an
            # "output" field previously crashed on `None[:...]`.
            if output_ids is not None:
                output_ids = output_ids[:self.max_length - self.max_prompt_length]
            data.append({
                "prompt_ids": prompt_ids,
                "output_ids": output_ids
            })
        print_rank("Load End")
        return data, data_origin

    def load_data_txt(self, data_path):
        """Read one prompt per line from `{split}.txt` and tokenize each."""
        with open(os.path.join(data_path, f"{self.split}.txt")) as f:
            lines = f.readlines()
        data = []
        print_rank("Loading Data")
        for line in lines:
            line = line.strip()
            line = line.replace("<n>", "\n")
            prompt = self.tokenizer.encode(line)
            data.append(prompt)
        print_rank("Load End")
        return data

    def verbalizer(self):
        """Return the {first answer token id -> answer string} map."""
        return self.label_map

    def __getitem__(self, index: int):
        """Return (index, prompt token ids, remaining/reference token ids)."""
        data = self.data[index]
        if self.args.bin_data:
            data = data.astype(int)
        elif self.args.json_data:
            output_ids = data["output_ids"]
            data = data["prompt_ids"]
        
        prompt_length = self.max_prompt_length

        prompt = data[:prompt_length]
        rest = data[prompt_length:]
        # For json data, prefer the reference output over the prompt tail.
        if self.args.json_data:
            if output_ids is not None:
                rest = output_ids

        return index, prompt, rest

    def collate(self, samples):
        """Collate (idx, prompt, rest) triples into left-padded batches."""
        bs = len(samples)
        
        max_prompt_length = self.max_prompt_length
        max_rest_length = max([len(samp[2]) for samp in samples])
        
        model_batch = {
            "input_ids": torch.ones(bs, max_prompt_length, dtype=torch.long) * self.pad_id,
            "attention_mask": torch.zeros(bs, max_prompt_length, dtype=torch.long),
        }
        
        no_model_batch = {
            "idx": torch.zeros(bs, dtype=torch.long),
            "rest_ids": torch.ones(bs, max_rest_length, dtype=torch.long) * self.pad_id
        }
        
        for i, (idx, prompt, rest) in enumerate(samples):
            # Left padding: prompts are right-aligned in the batch tensor.
            model_batch["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
            model_batch["attention_mask"][i][-len(prompt):] = 1
            no_model_batch["idx"][i] = idx
            no_model_batch["rest_ids"][i][:len(rest)] = torch.tensor(rest, dtype=torch.long)
        
        return model_batch, no_model_batch

    def move_to_device(self, model_batch, no_model_batch, device):
        """Move both batch dicts to `device` in place and return them."""
        for k in model_batch:
            model_batch[k] = model_batch[k].to(device)
        for k in no_model_batch:
            no_model_batch[k] = no_model_batch[k].to(device)
        
        return model_batch, no_model_batch


================================================
FILE: distillm/__init__.py
================================================
from .losses import forward_kl, reverse_kl, symmetric_kl, js_distance, tv_distance
from .losses import skewed_forward_kl, skewed_reverse_kl
from .sampler import SampleGenerator
from .buffer import ReplayBuffer


================================================
FILE: distillm/buffer.py
================================================
import random
import torch
import os
import json
import pickle
import numpy as np
from torch.utils.data import Dataset

from torch.distributed import get_rank, get_world_size, barrier
from utils import print_rank
from utils import save_rank

from collections import namedtuple, deque


class ReplayBuffer:
    """Fixed-capacity FIFO memory of generated training examples.

    Entries are per-sample tensor namedtuples; `sample` draws a random
    mini-batch and re-stacks it into model / no-model batch dicts.
    """

    def __init__(self, args):
        self.args = args
        self.replay_memory = deque(maxlen=args.capacity)
        self.bs = args.batch_size
        # gpt2/llama batches additionally carry explicit position ids.
        if args.model_type in ["gpt2", "llama"]:
            fields = ["input_ids", "attention_mask", "position_ids", "label", "loss_mask"]
        else:
            fields = ["input_ids", "attention_mask", "label", "loss_mask"]
        self.data = namedtuple("Generation", field_names=fields)

    def __len__(self):
        return len(self.replay_memory)

    def sample(self):
        """Draw `bs` random entries and stack them back into batch dicts."""
        batch = random.sample(self.replay_memory, k=self.bs)

        def _stack(field):
            return torch.stack([getattr(entry, field) for entry in batch], dim=0)

        model_data = {
            "input_ids": _stack("input_ids"),
            "attention_mask": _stack("attention_mask"),
        }
        if self.args.model_type in ["gpt2", "llama"]:
            model_data["position_ids"] = _stack("position_ids")

        no_model_data = {
            "label": _stack("label"),
            "loss_mask": _stack("loss_mask"),
        }
        return model_data, no_model_data

    def move_to_device(self, model_data, no_model_data, device):
        """Move both batch dicts to `device` in place and return them."""
        for batch in (model_data, no_model_data):
            for key in batch:
                batch[key] = batch[key].to(device)
        return model_data, no_model_data

    def move_to_memory(self, model_data, no_model_data):
        """Copy a batch to CPU and append each row as one memory entry."""
        cpu = torch.device("cpu")
        md = {key: val.to(cpu) for key, val in model_data.items()}
        nmd = {key: val.to(cpu) for key, val in no_model_data.items()}

        uses_position_ids = self.args.model_type in ["gpt2", "llama"]
        for row in range(md["input_ids"].size(0)):
            if uses_position_ids:
                entry = self.data(md["input_ids"][row], md["attention_mask"][row],
                                  md["position_ids"][row],
                                  nmd["label"][row], nmd["loss_mask"][row])
            else:
                entry = self.data(md["input_ids"][row], md["attention_mask"][row],
                                  nmd["label"][row], nmd["loss_mask"][row])
            self.replay_memory.append(entry)

================================================
FILE: distillm/losses.py
================================================
import torch
import torch.nn.functional as F

def forward_kl(logits, teacher_logits, no_model_batch):
    """Masked teacher->student cross-entropy (forward-KL distillation loss).

    Averages -sum_v p_teacher(v) * log p_student(v) over positions whose
    label is not -100. Positions with infinite student logits are zeroed
    out to avoid NaN/inf contributions.
    """
    p_teacher = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    invalid = torch.isinf(logits)
    per_token = (p_teacher * log_q).masked_fill(invalid, 0).sum(dim=-1).view(-1)
    valid = (no_model_batch["label"] != -100).int().view(-1)
    return -(per_token * valid).sum() / valid.sum()

def reverse_kl(logits, teacher_logits, no_model_batch):
    """Masked reverse KL, KL(student || teacher), averaged over valid tokens."""
    q = F.softmax(logits, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    log_p = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
    invalid = torch.isinf(teacher_logits) | torch.isinf(logits)
    # sum_v q * (log p - log q), with infinite-logit entries zeroed per term.
    gap = (q * log_p).masked_fill(invalid, 0) - (q * log_q).masked_fill(invalid, 0)
    per_token = gap.sum(dim=-1).view(-1)
    valid = (no_model_batch["label"] != -100).int().view(-1)
    return -(per_token * valid).sum() / valid.sum()

def symmetric_kl(logits, teacher_logits, no_model_batch, lam=0.9):
    """Convex combination of forward and reverse KL: (1-lam)*FKL + lam*RKL."""
    fkl = forward_kl(logits, teacher_logits, no_model_batch)
    rkl = reverse_kl(logits, teacher_logits, no_model_batch)
    return (1 - lam) * fkl + lam * rkl
    
def js_distance(logits, teacher_logits, no_model_batch, lam=0.9):
    """Generalized Jensen-Shannon loss with mixture m = (1-lam)*p_T + lam*p_S.

    Returns lam*KL(p_S || m) + (1-lam)*KL(p_T || m), averaged over tokens
    whose label is not -100.
    """
    p = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    q = F.softmax(logits, dim=-1, dtype=torch.float32)
    log_m = torch.log((1 - lam) * p + lam * q)

    log_p = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    valid = (no_model_batch["label"] != -100).int().view(-1)
    invalid = torch.isinf(logits) | torch.isinf(teacher_logits)
    denom = valid.sum()

    def _masked_kl(probs, log_probs):
        # KL(probs || m) masked and averaged over valid tokens.
        diff = (probs * log_m).masked_fill(invalid, 0) - (probs * log_probs).masked_fill(invalid, 0)
        per_token = diff.sum(dim=-1).view(-1)
        return -(per_token * valid).sum() / denom

    return lam * _masked_kl(q, log_q) + (1 - lam) * _masked_kl(p, log_p)
    
def tv_distance(logits, teacher_logits, no_model_batch):
    """Masked total-variation distance 0.5 * sum_v |p_T - p_S| per valid token."""
    p = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    q = F.softmax(logits, dim=-1, dtype=torch.float32)
    invalid = torch.isinf(logits) | torch.isinf(teacher_logits)
    half_abs_diff = 0.5 * (p - q).abs().masked_fill(invalid, 0)
    per_token = half_abs_diff.sum(dim=-1).view(-1)
    valid = (no_model_batch["label"] != -100).int().view(-1)
    # No negation: TV distance is already non-negative.
    return (per_token * valid).sum() / valid.sum()

def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
    """Skewed forward-KL surrogate (DistiLLM): teacher cross-entropy against
    the mixture m = lam*p_T + (1-lam)*p_S, averaged over valid tokens."""
    p = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    q = F.softmax(logits, dim=-1, dtype=torch.float32)
    log_m = torch.log(lam * p + (1 - lam) * q)

    valid = (no_model_batch["label"] != -100).int().view(-1)
    invalid = torch.isinf(logits) | torch.isinf(teacher_logits)

    per_token = (p * log_m).masked_fill(invalid, 0).sum(dim=-1).view(-1)
    return -(per_token * valid).sum() / valid.sum()

def skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=0.1):
    """Skewed reverse KL (DistiLLM): KL(p_S || m) with m = (1-lam)*p_T + lam*p_S."""
    p = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    q = F.softmax(logits, dim=-1, dtype=torch.float32)
    log_m = torch.log((1 - lam) * p + lam * q)
    log_q = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    valid = (no_model_batch["label"] != -100).int().view(-1)
    invalid = torch.isinf(logits) | torch.isinf(teacher_logits)

    # sum_v q * (log m - log q), with infinite-logit entries zeroed per term.
    gap = (q * log_m).masked_fill(invalid, 0) - (q * log_q).masked_fill(invalid, 0)
    per_token = gap.sum(dim=-1).view(-1)
    return -(per_token * valid).sum() / valid.sum()

================================================
FILE: distillm/sampler.py
================================================
import torch
import os
from transformers import GenerationConfig


class SampleGenerator():
    """Samples continuations from a model and re-packs them into fixed-size
    training tensors for distillation."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        # Budget for newly generated tokens after the prompt.
        self.max_new_token = self.args.max_length - self.args.max_prompt_length
        self.pad_id = tokenizer.pad_token_id
        self.generation_config = GenerationConfig(
            do_sample=args.do_sample,
            top_p=args.gen_top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            max_length=args.max_length,
            min_length=None,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=False
        )
        
    def run_sample(self, model, gen_data):
        """Generate continuations for `gen_data` prompts and left-pack them.

        Returns a dict of (batch, args.max_length) tensors: `input_ids`,
        `attention_mask`, `position_ids`, and `no_model_batch` (response ids
        at response positions, -100 elsewhere).
        """
        bs = gen_data["input_ids"].size(0)
        results = {
            "input_ids": torch.ones(bs, self.args.max_length, dtype=torch.long, device=gen_data["input_ids"].device) * self.pad_id,
            "attention_mask": torch.zeros(bs, self.args.max_length, dtype=torch.float,  device=gen_data["input_ids"].device),
            "position_ids": torch.zeros(bs, self.args.max_length, dtype=torch.long,  device=gen_data["input_ids"].device),
            "no_model_batch": torch.ones(bs, self.args.max_length, dtype=torch.long, device=gen_data["input_ids"].device) * -100,
        }
        
        model.eval()
        with torch.no_grad():
            gen_out = model.generate(
                **gen_data,
                generation_config=self.generation_config,
                max_new_tokens=self.max_new_token,
            )
            
            full_ids = gen_out.sequences
            # Split the returned sequences back into prompt and response parts.
            input_ids = full_ids[:, :gen_data["input_ids"].size(1)]
            response_ids = full_ids[:, gen_data["input_ids"].size(1):]
            
            for i in range(len(input_ids)):
                # Strip pad ids and re-pack prompt+response left-aligned.
                # NOTE(review): any genuine token equal to pad_id inside the
                # response is stripped too — verify pad_id never occurs as a
                # real token here.
                result_id = torch.cat(
                    (input_ids[i][input_ids[i] != self.pad_id],
                     response_ids[i][response_ids[i] != self.pad_id]),
                )
                input_id = input_ids[i][input_ids[i] != self.pad_id]
                response_id = response_ids[i][response_ids[i] != self.pad_id]
                
                results["input_ids"][i, :len(result_id)] = result_id
                results["position_ids"][i, :len(result_id)] = torch.arange(len(result_id))
                # Labels: response tokens at their positions, -100 elsewhere.
                results["no_model_batch"][i, len(input_id):len(result_id)] = response_id
        # Attend to every non-pad position of the re-packed sequences.
        results["attention_mask"] = torch.where(results["input_ids"] != self.pad_id, 1, 0)
        results["attention_mask"] = results["attention_mask"].float()
        results["no_model_batch"] = results["no_model_batch"].long()
        return results

================================================
FILE: evaluate.py
================================================
import time
import os

import torch
import torch.distributed as dist
import deepspeed

import json

from arguments import get_args

from utils import initialize, print_args
from utils import print_rank
from utils import save_rank
from utils import get_tokenizer, get_model

from evaluate_main import evaluate_main, prepare_dataset_main


torch.set_num_threads(4)


def setup_model(args, ds_config, device):
    """Build the model and wrap it in a DeepSpeed engine (evaluation only:
    optimizer and LR scheduler are deliberately None)."""
    model = get_model(args, device)

    engine, _, _, _ = deepspeed.initialize(
        model=model,
        optimizer=None,
        args=args,
        lr_scheduler=None,
        mpu=None,
        config_params=ds_config
    )

    # Log GPU memory usage after initialization.
    print_rank("Model mem\n", torch.cuda.memory_summary())
    return engine


def main():
    """Entry point: set up distributed state, data, and model, then evaluate."""
    torch.backends.cudnn.enabled = False
    
    args = get_args()
    initialize(args)
    
    # Only rank 0 dumps the run configuration to disk.
    if dist.get_rank() == 0:
        print_args(args)
        with open(os.path.join(args.save, "args.json"), "w") as f:
            json.dump(vars(args), f)
    
    device = torch.cuda.current_device()
    cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))
    print("OK")
    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    # Mirror CLI batch/accumulation settings into the DeepSpeed config.
    ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
    ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
    ds_config["gradient_clipping"] = args.clip_grad
    ds_config["steps_per_print"] = args.gradient_accumulation_steps
    
    # ZeRO partitioning is unnecessary for pure evaluation.
    if not args.do_train:
        ds_config["zero_optimization"]["stage"] = 0

    # NOTE(review): assumes ds_config always has an "fp16" section — KeyError otherwise.
    args.fp32 = not ds_config["fp16"]["enabled"] 
    args.deepspeed_config = None

    # get the tokenizer
    tokenizer = get_tokenizer(args)
    if args.type == "eval_main":
        dataset = prepare_dataset_main(
            args,
            tokenizer,
        )
    else:
        raise NotImplementedError
    model = setup_model(args, ds_config, device)
    
    if args.type == "eval_main":
        evaluate_main(args, tokenizer, model, dataset["test"], "test", 0, device)
    else:
        raise NotImplementedError
    
    
if __name__ == "__main__":
    main()

================================================
FILE: evaluate_main.py
================================================
from data_utils.prompt_datasets import PromptDataset
from transformers import GenerationConfig
import os
import nltk
nltk.download("punkt")

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import json
from utils import print_rank, save_rank, all_gather

from rouge_metric import compute_metrics

torch.set_num_threads(4)


def prepare_dataset_main(args, tokenizer):
    """Load the evaluation split ("valid") and expose it under the "test" key."""
    return {
        "test": PromptDataset(args, tokenizer, "valid", args.data_dir, args.dev_num)
    }


def run_model(args, tokenizer, model, dataset: PromptDataset, epoch, device):
    """Run teacher-forced LM loss and free-running generation over `dataset`.

    Returns:
        (mean_lm_loss, all_query_ids, all_response_ids): scalar LM loss
        averaged across ranks, and gathered fixed-width query/response id
        tensors truncated to len(dataset).
    """
    collate_fn = dataset.collate
    dp_world_size = dist.get_world_size()
    dp_rank = dist.get_rank()
    dp_group = None
    
    # No shuffling/dropping so gathered outputs line up with dataset order.
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
    model.eval()
    
    all_query_ids = []
    all_response_ids = []
    all_lm_losses = []
    
    generation_config = GenerationConfig (
        do_sample=args.do_sample,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        repetition_penalty=args.repetition_penalty,
        max_length=args.max_length,
        min_length=None,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True
    )

    with torch.no_grad():
        for it, (model_batch, no_model_batch) in enumerate(tqdm(dataloader, desc=f"Evaluating {args.data_names} ", disable=(dist.get_rank() != 0))):
            if it == 0:
                # Show one decoded prompt so the run can be sanity-checked.
                print_rank("############### Example ###############")
                print_rank(tokenizer.decode(model_batch["input_ids"][0], skip_special_tokens=True))
                print_rank("############### End ###############")
            
            dataset.move_to_device(model_batch, no_model_batch, device)

            # Teacher-forced pass: concatenate prompt + reference, shift by one.
            all_ids = torch.cat([model_batch["input_ids"], no_model_batch["rest_ids"]], dim=-1)
            input_ids = all_ids[:, :-1]
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
            label_ids = all_ids[:, 1:]
            # Ignore pad positions and the prompt part in the loss.
            label_ids = torch.masked_fill(label_ids, label_ids==tokenizer.pad_token_id, -100)
            label_ids[:, :model_batch["input_ids"].size(1)-1] = -100  
            if args.model_type in ["gpt2"]:
                # Positions restart after the left padding.
                position_ids = (torch.cumsum(attention_mask, dim=-1) - 1) * attention_mask
                out = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, return_dict=True)
            else:
                out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            logits = out.logits
            loss_mask = (label_ids != -100).float()
            loss_func = nn.CrossEntropyLoss(reduction="none")
            lm_loss = loss_func(logits.view(-1, logits.size(-1)), label_ids.view(-1)).view(label_ids.size())
            # Per-sample mean over valid (non -100) positions.
            lm_loss = torch.sum(lm_loss * loss_mask, -1) / torch.sum(loss_mask, -1)
            all_lm_losses.append(lm_loss)

            # Free-running generation from the prompt only.
            query_ids = model_batch["input_ids"]
            max_new_tokens = args.max_length - query_ids.size(1)
            gen_out = model.generate(
                **model_batch,
                generation_config=generation_config,
                max_new_tokens=max_new_tokens
            )
            full_ids = gen_out.sequences
            response_ids = full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
            
            # Pad to fixed widths so cross-rank gathering lines up.
            query_ids = F.pad(query_ids, (args.max_prompt_length-query_ids.size(1), 0, 0, 0), value=tokenizer.pad_token_id)
            response_ids = F.pad(response_ids, (0, args.max_length-args.max_prompt_length-response_ids.size(1), 0, 0), value=tokenizer.pad_token_id)
            
            all_query_ids.append(query_ids)
            all_response_ids.append(response_ids)

    all_lm_losses = torch.cat(all_lm_losses)
    mean_lm_loss = all_lm_losses.mean()
    dist.all_reduce(mean_lm_loss, dist.ReduceOp.SUM, group=dp_group)
    mean_lm_loss = mean_lm_loss.item() / dp_world_size
        
    all_query_ids = torch.cat(all_query_ids)
    all_query_ids = all_gather(all_query_ids, dim=1, group=dp_group, world_size=dp_world_size, op="stack")
    all_query_ids = all_query_ids.view(-1, all_query_ids.size(-1))
    # Drop the duplicated tail introduced by DistributedSampler padding.
    all_query_ids = all_query_ids[:len(dataset)]
    
    all_response_ids = torch.cat(all_response_ids)
    all_response_ids = all_gather(all_response_ids, dim=1, group=dp_group, world_size=dp_world_size, op="stack")
    all_response_ids = all_response_ids.view(-1, all_response_ids.size(-1))
    all_response_ids = all_response_ids[:len(dataset)]
        
    return (
        mean_lm_loss,
        all_query_ids,
        all_response_ids)


def evaluate_main(args, tokenizer, model, dataset: PromptDataset, split, epoch, device):
    """Generate responses, write preds/answers files, and log ROUGE metrics."""
    lm_loss, query_ids, response_ids = run_model(args, tokenizer, model, dataset, epoch, device)
    query_strs = tokenizer.batch_decode(query_ids, skip_special_tokens=True)
    response_strs = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
    
    # Tab-separated query/response dump; newlines encoded as "<n>".
    with open(os.path.join(args.save, "preds.txt"), "w") as f:
        for q, r in zip(query_strs, response_strs):
            f.write(q.replace("\n", "<n>") + "\t\t" + r.replace("\n", "<n>") + "\n")

    all_preds = [[]]
    for q, r in zip(query_strs, response_strs):
        all_preds[0].append((q, q + r))
    torch.save(all_preds, os.path.join(args.save, "preds.pt"))

    all_responses = []
    with open(os.path.join(args.save, "answers.jsonl"), "w") as f:    
        for p in all_preds[0]:
            q, r = p
            # Keep only the generated continuation, truncated at eos.
            r = r[len(q):]
            idx = r.find("<|endoftext|>")
            if idx >= 0:
                r = r[:idx]
            f.write(json.dumps({
                "text": r.replace("<n>", "\n").strip()
            }) + "\n")
            all_responses.append(r.replace("<n>", "\n").strip())
    
    # ROUGE (and related metrics) against the reference answers.
    gen_res = compute_metrics(all_responses, dataset.answers)

    mean_gen_length = np.mean([len(tokenizer.encode(s)) for s in response_strs])

    # NOTE(review): "lenth" typo kept as-is in case downstream log parsers match it.
    log_str = f"{split} | name: {args.data_names} | {gen_res} | lm_loss {round(lm_loss, 4)} | avg. gen lenth: {mean_gen_length}"
    print_rank(log_str)
    save_rank(log_str, os.path.join(args.save, "log.txt"))


================================================
FILE: finetune.py
================================================
import time
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim import AdamW
import deepspeed

import random
import json
from tqdm import tqdm
import math
import datetime

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    GenerationConfig)

from transformers import get_constant_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingLR

from arguments import get_args

from data_utils.lm_datasets import LMTrainDataset
from utils import get_optimizer_params, get_optimizer_params_peft, print_args, initialize
from utils import print_rank, get_rank
from utils import save_rank
from utils import all_gather
from utils import load_parallel, save_parallel
from utils import get_tokenizer, get_model

from distillm import forward_kl, reverse_kl, js_distance, tv_distance
from distillm import skewed_forward_kl, skewed_reverse_kl
from distillm import SampleGenerator, ReplayBuffer

from rouge_metric import compute_metrics

from peft import PeftModel

torch.set_num_threads(4)


def get_teacher_model(args, device):
    """Load the (frozen) teacher model used for distillation.

    Tries to load the checkpoint directly in fp16; if that fails (e.g. the
    checkpoint cannot be materialized with `torch_dtype=torch.float16`), falls
    back to loading in fp32 and casting to half precision afterwards.

    If a LoRA adapter is provided via `args.teacher_peft_path`, it is merged
    into the base weights so the teacher behaves as a plain causal LM.

    Returns:
        The teacher model in eval mode, placed on `device`.
    """
    config = AutoConfig.from_pretrained(args.teacher_model_path)
    if args.model_parallel:
        raise NotImplementedError
    else:
        config.is_model_parallel = False
        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.teacher_model_path, config=config, device_map={"": device}, torch_dtype=torch.float16)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
            # still propagate instead of being swallowed during model loading.
            model = AutoModelForCausalLM.from_pretrained(
                args.teacher_model_path, config=config, device_map={"": device}, torch_dtype=torch.float32)
            model = model.half()

        if args.peft is not None and args.teacher_peft_path is not None:
            if args.peft == "lora":
                model = PeftModel.from_pretrained(model, args.teacher_peft_path)
                model = model.merge_and_unload()
            else:
                raise NotImplementedError
        else:
            # Parameter count is only reported for the non-PEFT path (rank 0).
            if dist.get_rank() == 0:
                print(' > number of parameters: {}'.format(
                    sum([p.nelement() for p in model.parameters()])), flush=True)

    model.eval()

    return model


def get_optimizer(args, model):
    """Construct an AdamW optimizer over the model's parameter groups."""
    # Unwrap any DDP layer(s) to reach the underlying module.
    unwrapped = model
    while isinstance(unwrapped, DDP):
        unwrapped = unwrapped.module

    # PEFT training only optimizes the adapter parameters.
    if args.peft is None:
        groups = get_optimizer_params(args, unwrapped)
    else:
        groups = get_optimizer_params_peft(args, unwrapped)

    opt = AdamW(groups, lr=args.lr, weight_decay=args.weight_decay)
    print_rank(f'Optimizer = {opt.__class__.__name__}')
    return opt


def get_learning_rate_scheduler(args, optimizer):
    """Build the LR scheduler selected by `args.lr_decay_style`.

    Supported styles: "constant" (warmup then flat), "cosine" (cosine
    annealing down to `args.lr_min`), and "noam" (polynomial decay with
    warmup, power 0.5). Derives `args.total_iters` from the epoch count
    when it is unset.

    Raises:
        ValueError: for an unrecognized decay style.
    """
    if args.total_iters is None:
        args.total_iters = args.train_iters_per_epoch * args.epochs

    style = args.lr_decay_style
    if style == "constant":
        return get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_iters)
    if style == "cosine":
        return CosineAnnealingLR(
            optimizer, T_max=args.total_iters, eta_min=args.lr_min)
    if style == "noam":
        return get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_iters,
            num_training_steps=args.total_iters,
            power=0.5)
    raise ValueError(f"lr_scheduler of type {style} is not supported yet.")


def setup_model_and_optimizer(args, ds_config, device, set_optim=True):
    """Create the student model and wrap it in a DeepSpeed engine.

    When `set_optim` is False (eval-only runs), no optimizer or LR scheduler
    is constructed and DeepSpeed is initialized without them.
    """
    model = get_model(args, device)

    optimizer = lr_scheduler = None
    if set_optim:
        optimizer = get_optimizer(args, model)
        lr_scheduler = get_learning_rate_scheduler(args, optimizer)

    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        lr_scheduler=lr_scheduler,
        mpu=None,
        config_params=ds_config
    )

    # Report GPU memory usage once the engine is constructed.
    print_rank("Model mem\n", torch.cuda.memory_summary())
    return model, optimizer, lr_scheduler


def prepare_dataset(args, tokenizer):
    """Load the train/dev (or test) splits, plus the optional LM corpus.

    Raises:
        ValueError: when neither --do-train nor --do-eval is set.
    """
    rng = random.Random(args.seed)

    def _load(path, split, num, ratio):
        # One-line helper: build an LMTrainDataset with the shared RNG.
        return LMTrainDataset(args, tokenizer, path, split, num, ratio, rng)

    data = {}
    if args.do_train:
        data["train"] = _load(args.data_dir, "train", args.train_num, args.train_ratio)
        print_rank("train num", len(data["train"]))
        data["dev"] = _load(args.data_dir, "valid", args.dev_num, args.dev_ratio)
    elif args.do_eval:
        data["test"] = _load(args.data_dir, "valid", args.dev_num, args.dev_ratio)
    else:
        raise ValueError("Do train and do eval must set one")

    # Optional pre-training corpus used for the auxiliary LM objective.
    if args.do_train and args.lm_data_dir is not None:
        data["pt_train"] = _load(args.lm_data_dir, "train", args.train_num, args.train_ratio)
        print_rank("train num", len(data["pt_train"]))
    return data


def pt_loss(args, model, model_batch, no_model_batch):
    """Token-level cross-entropy LM loss on a pre-training batch.

    Positions whose label is -100 are ignored (the standard HF masking
    convention). `args` is accepted for signature parity with the other loss
    helpers but is unused here. (The dead `loss_mask` local computed by the
    previous version has been removed.)
    """
    outputs = model(**model_batch, return_dict=True, use_cache=False)
    logits = outputs.logits
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    # Flatten (batch, seq, vocab) -> (batch*seq, vocab) to score every position.
    lm_loss = loss_fn(logits.view(-1, logits.size(-1)), no_model_batch["label"].view(-1))
    return lm_loss


def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
    """Compute the distillation loss selected by `args.type`.

    The teacher is run without gradients on the same batch; the divergence
    between student `logits` and the teacher's logits is then chosen by
    substring-matching on `args.type`.
    """
    with torch.no_grad():
        teacher_model.eval()
        teacher_logits = teacher_model(**model_batch, use_cache=False).logits

    if args.model_parallel:
        raise NotImplementedError

    # Order matters below: e.g. "sfkl" also contains "fkl" as a substring,
    # so the skewed variants must be checked before the plain ones.
    if "sfkl" in args.type:
        return skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=args.skew_alpha)
    if "srkl" in args.type:
        return skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=args.skew_alpha)
    if "jsd" in args.type:
        return js_distance(logits, teacher_logits, no_model_batch)
    if "tvd" in args.type:
        return tv_distance(logits, teacher_logits, no_model_batch)
    if "fkl" in args.type or args.type == "kd":
        return forward_kl(logits, teacher_logits, no_model_batch)
    if "rkl" in args.type:
        return reverse_kl(logits, teacher_logits, no_model_batch)
    raise NotImplementedError


def get_teacher_lm_loss(args, tokenizer, model, teacher_model, model_batch):
    """Cross-entropy of the student on sequences sampled from the teacher.

    The teacher free-runs (top-k disabled, top-p = 1, temperature 1.0) from
    the batch prompts; the student is then scored on the sampled sequences,
    with pad positions and the prompt prefix masked out of the loss.
    """
    with torch.no_grad():
        t_gen_out = teacher_model.generate(
            **model_batch,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_length=args.max_length,
            top_k=0,
            top_p=1,
            temperature=1.0,
            do_sample=True,
            return_dict_in_generate=True,
            output_scores=False)

    full_ids = t_gen_out.sequences

    # Shift by one: predict token t+1 from tokens <= t.
    input_ids = full_ids[:, :-1]
    mask = (input_ids != tokenizer.pad_token_id).long()
    labels = full_ids[:, 1:]
    labels = torch.masked_fill(labels, mask == 0, -100)
    # Train only on the teacher's continuation, never on the prompt itself.
    prompt_len = model_batch["input_ids"].size(1)
    labels[:, :prompt_len - 1] = -100
    loss_mask = (labels != -100).float()  # kept for parity; unused downstream

    new_batch = {
        "input_ids": input_ids,
        "attention_mask": mask,
    }

    if args.model_type in ["gpt2"]:
        # GPT-2 requires explicit position ids that do not advance on padding.
        position_ids = torch.cumsum(mask, dim=-1) - 1
        position_ids = torch.masked_fill(position_ids, mask == 0, 0)
        new_batch["position_ids"] = position_ids

    outputs = model(**new_batch, return_dict=True, use_cache=False)
    logits = outputs.logits
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    lm_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

    return lm_loss


def finetune(args, tokenizer: AutoTokenizer, model: deepspeed.DeepSpeedEngine, optimizer: AdamW, lr_scheduler, dataset, device, teacher_model=None):
    """Run the fine-tuning / distillation training loop.

    Trains `model` on dataset["train"], optionally distilling from
    `teacher_model`, mixing in student-generated samples (the "mixed" /
    "adaptive" schemes) and an auxiliary LM loss when `args.lm_data_dir`
    is set. Logs, checkpoints, and evaluates at the configured intervals.

    Returns:
        The trained model engine.
    """
    print_rank("Start Fine-tuning")

    # print_inspect(model, '*')
    if args.model_parallel:
        raise NotImplementedError
    else:
        dp_world_size = dist.get_world_size()
        dp_rank = dist.get_rank()
        dp_group = None
        loss_func = nn.CrossEntropyLoss()

    sampler = DistributedSampler(dataset["train"], shuffle=True, drop_last=True, rank=dp_rank, num_replicas=dp_world_size)
    train_dataloader = DataLoader(
        dataset['train'], sampler=sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset["train"].collate)

    if "pt_train" in dataset:
        pt_sampler = DistributedSampler(dataset["pt_train"], shuffle=True, drop_last=True, rank=dp_rank, num_replicas=dp_world_size)
        pt_train_dataloader = DataLoader(
            dataset['pt_train'], sampler=pt_sampler, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=dataset["pt_train"].collate)
        pt_train_iter = iter(pt_train_dataloader)

    student_generator = SampleGenerator(args, tokenizer)

    step, global_step = 1, 1
    total_loss, total_distil_loss, total_time = 0.0, 0.0, 0.0

    # Adaptive student-sampling threshold (DistiLLM); None for other types.
    adaptive_threshold = args.init_threshold if "adaptive" in args.type else None
    prev_avg_loss = evaluate(args, tokenizer, model, dataset["dev"], "dev", 0, device, adaptive_threshold)
    replay_buffer = ReplayBuffer(args)

    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)

        model.train()
        for it, (model_batch, no_model_batch, gen_data) in enumerate(train_dataloader):
            dataset["train"].move_to_device(model_batch, no_model_batch, gen_data, device)

            if args.lm_data_dir is not None:
                # Cycle the auxiliary LM dataloader forever (narrowed from a
                # bare `except:` so real errors are not silently swallowed).
                try:
                    pt_model_batch, pt_no_model_batch, pt_gen_data = next(pt_train_iter)
                except StopIteration:
                    pt_train_iter = iter(pt_train_dataloader)
                    pt_model_batch, pt_no_model_batch, pt_gen_data = next(pt_train_iter)

                dataset["pt_train"].move_to_device(pt_model_batch, pt_no_model_batch, pt_gen_data, device)

            torch.cuda.synchronize()
            st_time = time.time()

            # Student-sampling threshold schedule. Only defined for "adaptive"
            # variants; previously it was computed unconditionally, which
            # raised TypeError (None * float) for every non-adaptive type.
            samp_threshold = None
            if "adaptive" in args.type:
                if args.replay_ratio == "constant":
                    samp_threshold = adaptive_threshold * 0.5
                elif args.replay_ratio == "increasing":
                    samp_threshold = adaptive_threshold * global_step / args.total_iters
                else:
                    samp_threshold = adaptive_threshold * (1 - global_step / args.total_iters)

            # Optionally replace the batch with student-generated samples.
            if args.student_gen:
                r = np.random.uniform(0, 1)
                if "mixed" in args.type and r < args.mixed_alpha:
                    model_batch = student_generator.run_sample(model, gen_data)
                    no_model_batch["label"] = model_batch.pop("no_model_batch")

                    replay_buffer.move_to_memory(model_batch, no_model_batch)
                    model_batch, no_model_batch = replay_buffer.sample()
                    model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)

                elif "adaptive" in args.type and (r < samp_threshold or (r < adaptive_threshold and len(replay_buffer) < args.capacity)):
                    # Generate fresh student samples and store them in the buffer.
                    model_batch = student_generator.run_sample(model, gen_data)
                    no_model_batch["label"] = model_batch.pop("no_model_batch")

                    if args.model_type in ["opt"]:
                        model_batch.pop('position_ids')

                    replay_buffer.move_to_memory(model_batch, no_model_batch)

                elif "adaptive" in args.type and r < adaptive_threshold:
                    # Re-train on samples drawn from the replay buffer.
                    model_batch, no_model_batch = replay_buffer.sample()
                    model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)

                model.train()

            outputs = model(**model_batch, use_cache=False)

            logits = outputs.logits
            if args.model_parallel:
                raise NotImplementedError
            else:
                lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))

            if teacher_model is not None:
                distil_loss = get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits)
                loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
            else:
                loss = lm_loss

            if args.lm_data_dir is not None:
                assert args.lm_coef is not None
                loss += args.lm_coef * pt_loss(args, model, pt_model_batch, pt_no_model_batch)

            model.backward(loss)
            model.step()

            dist.all_reduce(loss, dist.ReduceOp.SUM, group=dp_group)
            global_loss = loss.item() / dp_world_size

            global_distil_loss = 0
            if teacher_model is not None:
                dist.all_reduce(distil_loss, dist.ReduceOp.SUM, group=dp_group)
                global_distil_loss = distil_loss.item() / dp_world_size
                total_distil_loss += global_distil_loss

            torch.cuda.synchronize()
            elapsed_time = time.time() - st_time

            total_loss += global_loss
            total_time += elapsed_time

            # Logging
            def get_log(log_loss, log_distil_loss, log_time):
                return "train | epoch {:3d} | Iter: {:6d}/{:6d} | global iter: {:6d}/{:6d} | loss: {:.4f} | ds_loss: {:.4f} | lr: {:.4e} | scale: {:10.4f} | micro time: {:.3f} | step time: {:.3f}".format(
                    epoch,
                    step,
                    args.total_iters * args.gradient_accumulation_steps,
                    global_step,
                    args.total_iters,
                    log_loss,
                    log_distil_loss,
                    lr_scheduler.get_last_lr()[0],
                    optimizer.cur_scale if hasattr(optimizer, "cur_scale") else 0,
                    elapsed_time,
                    log_time,
                )

            if args.mid_log_num > 0:
                mid_log_step = args.gradient_accumulation_steps // args.mid_log_num
                mid_log_step = 1 if mid_log_step == 0 else mid_log_step
                if step % mid_log_step == 0:
                    print_rank(get_log(global_loss, global_distil_loss, 0))

            if global_step % args.log_interval == 0 and step % args.gradient_accumulation_steps == 0:
                log_str = get_log(
                    total_loss / (args.log_interval * args.gradient_accumulation_steps),
                    total_distil_loss / (args.log_interval * args.gradient_accumulation_steps),
                    total_time / (args.log_interval))
                print_rank("*" * 100)
                print_rank(log_str)
                print_rank(args.save)
                print_rank("*" * 100)
                save_rank(log_str, os.path.join(args.save, "log.txt"))
                total_loss, total_distil_loss, total_time = 0.0, 0.0, 0.0

            # Checkpointing: rank 0 writes, all ranks wait at the barrier.
            if args.save and args.save_interval and global_step % args.save_interval == 0 and step % args.gradient_accumulation_steps == 0:
                save_dir_path = os.path.join(args.save, str(global_step))
                if args.model_parallel:
                    raise NotImplementedError
                else:
                    if dist.get_rank() == 0:
                        os.makedirs(save_dir_path, exist_ok=True)
                        print_rank(f"Model save to {save_dir_path}")
                        tokenizer.save_pretrained(save_dir_path)
                        model.module.save_pretrained(save_dir_path, safe_serialization=False)
                dist.barrier()

            # Evaluation; adaptive types grow the sampling threshold when the
            # dev loss stops improving.
            if args.eval_interval and global_step % args.eval_interval == 0 and step % args.gradient_accumulation_steps == 0:
                curr_avg_loss = evaluate(args, tokenizer, model, dataset["dev"], "dev", epoch, device, adaptive_threshold)
                if "adaptive" in args.type:
                    if curr_avg_loss >= prev_avg_loss + args.loss_eps:
                        adaptive_threshold += 0.1
                        adaptive_threshold = min(adaptive_threshold, 1.0)
                        prev_avg_loss = curr_avg_loss

                model.train()

            step += 1
            if step % args.gradient_accumulation_steps == 0:
                global_step += 1

            if global_step > args.total_iters:
                break

    return model


def evaluate(args, tokenizer, model, dataset: LMTrainDataset, split, epoch, device, adaptive_threshold=None):
    """Distributed evaluation of the average LM loss (and optional generation).

    The per-batch loss is all-reduced across data-parallel ranks. When
    `args.eval_gen` is set, responses are decoded, padded to a fixed width,
    gathered across ranks, and scored against `dataset.answers` on rank 0,
    which also writes answers.jsonl and the log line. Returns the mean loss;
    the same value is returned on every rank. `adaptive_threshold` is only
    interpolated into the log string for the "adaptive" training variants.
    """
    collate_fn = dataset.collate

    if args.model_parallel:
        raise NotImplementedError
    else:
        dp_world_size = dist.get_world_size()
        dp_rank = dist.get_rank()
        dp_group = None
        loss_func = nn.CrossEntropyLoss()

    print_rank("dp size", dp_world_size)

    # Decoding settings shared by all eval batches.
    # NOTE(review): pad_token_id is set to eos_token_id — looks intentional for
    # GPT-2-style tokenizers that have no pad token; confirm for other models.
    generation_config = GenerationConfig(
        do_sample=args.do_sample,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        max_length=args.max_length,
        min_length=None,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=False
    )

    sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)

    model.eval()
    all_loss = 0.0
    step = 0

    all_response_ids = []

    with torch.no_grad():
        for it, (model_batch, no_model_batch, gen_data) in enumerate(tqdm(dataloader, desc="Evaluating", disable=(dist.get_rank() != 0))):
            print_rank(f"{it}/{len(dataloader)}")
            dataset.move_to_device(model_batch, no_model_batch, gen_data, device)
            logits = model(**model_batch).logits
            if args.model_parallel:
                raise NotImplementedError
            else:
                loss = loss_func(logits.view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))

            max_new_tokens = args.max_length - gen_data["input_ids"].size(1)

            if args.eval_gen:
                gen_out = model.generate(
                    **gen_data,
                    generation_config=generation_config,
                    max_new_tokens=max_new_tokens)

                full_ids = gen_out.sequences

                # Right-pad to args.max_length so tensors from different
                # batches/ranks can be concatenated and gathered uniformly.
                full_ids = F.pad(
                    full_ids,
                    (0, args.max_length - full_ids.shape[1]),
                    value=tokenizer.pad_token_id,
                )

                response_ids = full_ids[:, gen_data["input_ids"].size(1):]
                all_response_ids.append(response_ids)

            # Average the loss over data-parallel ranks.
            dist.all_reduce(loss, dist.ReduceOp.SUM, group=dp_group)
            loss = loss / dp_world_size
            all_loss += loss.item()
            step += 1

    if args.eval_gen:
        all_response_ids = torch.cat(all_response_ids, dim=0)
        all_response_ids = all_gather(all_response_ids, dim=1, world_size=dp_world_size, group=dp_group, op="stack")
        all_response_ids = all_response_ids.view(-1, all_response_ids.size(-1))

        responses = tokenizer.batch_decode(all_response_ids, skip_special_tokens=True)

    if get_rank() == 0:
        if args.eval_gen:
            references = dataset.answers
            # The distributed sampler may have padded the dataset; drop extras.
            responses = responses[:len(references)]

            res = compute_metrics(responses, references)

            eval_dir = os.path.join(args.save, "eval", str(epoch))
            print_rank(eval_dir)
            os.makedirs(eval_dir, exist_ok=True)
            with open(os.path.join(eval_dir, "answers.jsonl"), "w") as f:
                for resp in responses:
                    f.write(json.dumps({"text": resp}) + "\n")
        else:
            res = {}

        avg_loss = all_loss / step

        if "adaptive" in args.type:
            log_str = f"{split} | avg_loss: {avg_loss} | {res} | threshold: {adaptive_threshold}"
        else:
            log_str = f"{split} | avg_loss: {avg_loss} | {res}"
        print_rank(log_str)
        save_rank(log_str, os.path.join(args.save, "log.txt"))

    return all_loss / step


def main():
    """Entry point: set up distributed training, data, and models, then train/eval."""
    # NOTE(review): cuDNN is disabled by the original authors; the reason is
    # not stated in this file.
    torch.backends.cudnn.enabled = False

    args = get_args()
    initialize(args)

    # Rank 0 records the full argument set for reproducibility.
    if dist.get_rank() == 0:
        print_args(args)
        with open(os.path.join(args.save, "args.json"), "w") as f:
            json.dump(vars(args), f)

    device = torch.cuda.current_device()
    cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))

    # Command-line args override the static DeepSpeed JSON config.
    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
    ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
    ds_config["gradient_clipping"] = args.clip_grad
    ds_config["steps_per_print"] = 10000000

    # Eval-only runs don't need ZeRO partitioning.
    if not args.do_train:
        ds_config["zero_optimization"]["stage"] = 0

    args.fp32 = not ds_config["fp16"]["enabled"]
    args.deepspeed_config = None

    # get the tokenizer
    tokenizer = get_tokenizer(args)
    dataset = prepare_dataset(
        args,
        tokenizer,
    )

    dp_world_size = dist.get_world_size()

    if args.do_train:
        # Derive iteration/epoch counts from whichever of the two was given.
        args.train_iters_per_epoch = int(len(dataset["train"]) / (args.batch_size * dp_world_size * args.gradient_accumulation_steps))
        print_rank("Train iters per epoch", args.train_iters_per_epoch)
        if args.total_iters is None:
            args.total_iters = args.train_iters_per_epoch * args.epochs
        if args.epochs is None:
            args.epochs = math.ceil(args.total_iters / args.train_iters_per_epoch)
        print_rank("total_iters", args.total_iters)

        # -1 means "once per epoch".
        if args.save_interval == -1:
            args.save_interval = args.train_iters_per_epoch

        if args.eval_interval == -1:
            args.eval_interval = args.train_iters_per_epoch

    model, optimizer, lr_scheduler = setup_model_and_optimizer(args, ds_config, device, set_optim=args.do_train)

    if args.teacher_model_type is None:
        args.teacher_model_type = args.model_type

    if args.teacher_model_path is not None:
        teacher_model = get_teacher_model(args, device)
    else:
        teacher_model = None

    if args.do_train:
        model = finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset, device, teacher_model=teacher_model)

    if args.do_eval:
        evaluate(args, tokenizer, model, dataset["test"], "test", 0, device)


if __name__ == "__main__":
    main()

================================================
FILE: generate.py
================================================
import time
import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import deepspeed
import numpy as np

import json
from tqdm import tqdm

from transformers import mpu

from arguments import get_args

from data_utils.prompt_datasets import PromptDataset
from utils import print_args, initialize
from utils import print_rank, get_rank
from utils import save_rank
from utils import all_gather
from utils import get_tokenizer, get_model


torch.set_num_threads(4)


def setup_model(args, ds_config, device):
    """Build the generation model and wrap it in a DeepSpeed engine.

    Generation needs no optimizer or LR scheduler, so DeepSpeed is
    initialized with both set to None.
    """
    base_model = get_model(args, device)

    engine, _, _, _ = deepspeed.initialize(
        model=base_model,
        optimizer=None,
        args=args,
        lr_scheduler=None,
        mpu=mpu if args.model_parallel else None,
        config_params=ds_config
    )

    # Report GPU memory usage once the engine is constructed.
    print_rank("Model mem\n", torch.cuda.memory_summary())
    return engine


def prepare_dataset(args, tokenizer):
    """Load the prompt set used for generation.

    Returns a single PromptDataset (the previous dead `data = {}` dict
    initialization, immediately overwritten, has been removed).
    """
    data = PromptDataset(args, tokenizer, "train", data_path=args.data_dir, num=args.gen_num)
    print_rank("gen num", len(data))
    return data


def generate(args, tokenizer, model, dataset, device):
    """Sample continuations for every prompt in `dataset` and save them.

    Generation runs data-parallel: each rank pads its generations to a fixed
    width, results and example indices are all-gathered, and rank 0 decodes
    them, attaches each answer to its originating record, and writes
    raw.jsonl under args.save.
    """
    collate_fn = dataset.collate

    if args.model_parallel:
        dp_world_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        dp_group = mpu.get_data_parallel_group()
    else:
        dp_world_size = dist.get_world_size()
        dp_rank = dist.get_rank()
        dp_group = None

    sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)

    model.eval()
    all_gen_ids = []
    all_idxs = []
    max_new_tokens = args.max_length - args.max_prompt_length

    with torch.no_grad():
        for it, (model_batch, no_model_batch) in enumerate(tqdm(dataloader, desc="Generating", disable=(dist.get_rank() != 0))):
            dataset.move_to_device(model_batch, no_model_batch, device)
            t_gen_out = model.generate(
                **model_batch,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                top_k=args.top_k,
                top_p=args.top_p,
                temperature=args.temperature,
                do_sample=True,
                return_dict_in_generate=True,
                output_scores=False)

            # Keep only the continuation, right-padded to a fixed width so all
            # batches/ranks can be concatenated and gathered uniformly.
            full_ids = t_gen_out.sequences
            gen_ids = full_ids[:, model_batch["input_ids"].size(1):]
            buffer = torch.ones(gen_ids.size(0), max_new_tokens, dtype=torch.long, device=gen_ids.device) * tokenizer.pad_token_id
            buffer[:, :gen_ids.size(1)] = gen_ids
            all_gen_ids.append(buffer)
            all_idxs.append(no_model_batch["idx"])

    all_idxs = all_gather(torch.cat(all_idxs, dim=0), dim=0, world_size=dp_world_size, group=dp_group).cpu().tolist()
    all_gen_ids = all_gather(torch.cat(all_gen_ids, dim=0), dim=0, world_size=dp_world_size, group=dp_group).cpu().tolist()

    if get_rank() == 0:
        all_gen_strs = tokenizer.batch_decode(all_gen_ids, skip_special_tokens=True)
        # Average generation length over a 100-sample prefix, for the log only.
        mean_lens = np.mean([len(tokenizer.encode(x)) for x in all_gen_strs[:100]])

        log_str = f"gen | avg. lens: {mean_lens}"
        print_rank(log_str)
        save_rank(log_str, os.path.join(args.save, "log.txt"))

        assert len(all_idxs) == len(all_gen_strs)

        # Attach each generated answer back to its original record by index.
        for idx, g in zip(all_idxs, all_gen_strs):
            dataset.origin_data[idx]["gen_answer"] = g

        with open(os.path.join(args.save, "raw.jsonl"), "w") as f:
            for d in dataset.origin_data:
                if "gen_answer" in d:
                    f.write(json.dumps(d) + "\n")

    dist.barrier()


def main():
    """Entry point: parse args, build the model engine, run generation."""
    # NOTE(review): cuDNN is disabled by the original authors; the reason is
    # not stated in this file.
    torch.backends.cudnn.enabled = False

    args = get_args()
    initialize(args)

    # Rank 0 records the full argument set for reproducibility.
    if dist.get_rank() == 0:
        print_args(args)
        with open(os.path.join(args.save, "args.json"), "w") as f:
            json.dump(vars(args), f)

    device = torch.cuda.current_device()
    cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    save_rank("\n\n" + "="*30 + f" EXP at {cur_time} " + "="*30, os.path.join(args.save, "log.txt"))

    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    # Inference only: no ZeRO partitioning needed.
    ds_config["steps_per_print"] = args.gradient_accumulation_steps
    ds_config["zero_optimization"]["stage"] = 0

    args.fp32 = not ds_config["fp16"]["enabled"]
    args.deepspeed_config = None

    # get the tokenizer
    tokenizer = get_tokenizer(args)
    dataset = prepare_dataset(
        args,
        tokenizer,
    )

    model = setup_model(args, ds_config, device)

    generate(args, tokenizer, model, dataset, device)


if __name__ == "__main__":
    main()


================================================
FILE: install.sh
================================================
# DistiLLM environment setup script.
# PyTorch and transformers are expected to be installed separately — the
# commented lines below record the versions the authors used.
export NCCL_DEBUG=""
# conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
# pip install transformers==4.42.4
pip install vllm==0.5.0    # fast batched generation backend
pip install deepspeed      # distributed training engine
pip install nltk
pip install numerize
pip install rouge-score    # ROUGE evaluation metrics
pip install torchtyping
pip install rich
pip install accelerate
pip install datasets
pip install sentencepiece
pip install protobuf
pip install peft           # LoRA adapters

================================================
FILE: minillm/__init__.py
================================================
from deepspeed import DeepSpeedConfig
from typing import Optional

# from trlx.utils.loading import get_orchestrator, get_pipeline, get_trainer
from .sampler import PPOSampler
from .pipelines import PPOPipeline, LMPipeline
from .trainer import PPOTrainer
from .reward import Reward

def train(
    args,
    tokenizer,
    reward_fn = None,
    teacher_model=None,
    prompt_data: Optional[str] = None,
    eval_prompt_data: Optional[str] = None,
    lm_data: Optional[str] = None,
    eval_lm_data: Optional[str] = None,
    ds_config: Optional[DeepSpeedConfig] = None,
):
    """Wire up the MiniLLM PPO trainer, its pipelines and sampler, then train.

    Returns the trainer after training finishes.
    """
    ppo_trainer = PPOTrainer(
        args=args,
        tokenizer=tokenizer,
        reward_fn=reward_fn,
        ds_config=ds_config,
    )
    ppo_trainer.set_teacher_model(teacher_model)

    # Rollout pipeline + sampler: pre-fill the rollout storage before training.
    train_pipeline = PPOPipeline(args, tokenizer, "train", prompt_data, num=args.train_num)
    rollout_sampler = PPOSampler(args, ppo_trainer, train_pipeline, chunk_size=args.chunk_size)
    rollout_sampler.run_sample(args.num_rollouts_per_device)

    eval_pipeline = PPOPipeline(
        args, ppo_trainer.tokenizer, "valid", eval_prompt_data, fix_prompts=True, num=args.dev_num)
    ppo_trainer.add_eval_pipeline(eval_pipeline)

    # Optional language-modeling pipelines (auxiliary objective).
    lm_pipeline = None
    if lm_data is not None:
        lm_pipeline = LMPipeline(args, ppo_trainer.tokenizer, "train", lm_data, num=args.train_num)
    eval_lm_pipeline = None
    if eval_lm_data is not None:
        eval_lm_pipeline = LMPipeline(args, ppo_trainer.tokenizer, "valid", eval_lm_data, num=args.dev_num)
    ppo_trainer.add_lm_pipeline(lm_pipeline, eval_lm_pipeline)

    ppo_trainer.train()
    return ppo_trainer


================================================
FILE: minillm/data_types.py
================================================
from dataclasses import dataclass
from typing import Iterable
from torchtyping import TensorType


@dataclass
class PromptElement:
    """
    Dataclass for a single prompt, containing its string and tokenized form.

    :param text: The prompt text.
    :type text: str

    :param tokens: The prompt tokens, shape (num_tokens,). Should be a long tensor.
    :type tokens: torch.Tensor
    """

    text: str
    tokens: TensorType["num_tokens"]


@dataclass
class PromptBatch:
    """
    Batched version of PromptElement.

    :param text: An iterable of prompt texts.
    :type text: Iterable[str]

    :param tokens: A long tensor batch of prompt tokens,
        shape (batch_size, num_tokens).
    :type tokens: torch.Tensor
    """

    text: Iterable[str]
    tokens: TensorType["batch_size", "num_tokens"]


@dataclass
class PPORLElement:
    """
    A single PPO rollout sample (one prompt/response pair), as produced by
    the sampler in minillm/sampler.py.

    :param query_tensor: The prompt tokens. Long tensor of shape (query_size,).
    :param response_tensor: The generated response tokens. Long tensor of
        shape (response_size,).
    :param lens: Number of valid tokens in the (possibly teacher-mixed)
        generated response.
    :param s_lens: Number of valid tokens in the student-only response
        (equals ``lens`` when teacher-mixed sampling is disabled).
    :param mask: Mask over valid response positions, shape (response_size,).
    :param logprobs: Per-token log-probabilities of the sampled response
        under the policy, shape (response_size,).
    :param rewards: Per-token total rewards (teacher reward plus entropy
        bonus, optionally scaled and clipped by the sampler),
        shape (response_size,).
    :param rev_kl: Per-token reverse-KL estimate between the student and the
        teacher, shape (response_size,).
    :param w: Per-token importance-sampling weights (all ones when
        teacher-mixed sampling is disabled), shape (response_size,).
    :param inf_mask: Boolean mask marking vocabulary entries whose logits
        were -inf during generation, shape (response_size, vocab_size).
    :param t_rewards: Per-token teacher (reward-model) rewards,
        shape (response_size,).
    :param ent_rewards: Per-token entropy rewards (negative log-probs),
        shape (response_size,).
    """

    query_tensor: TensorType["query_size"]
    response_tensor: TensorType["response_size"]
    lens: int
    s_lens: int
    mask: TensorType["response_size"]
    logprobs: TensorType["response_size"]
    rewards: TensorType["response_size"]
    rev_kl: TensorType["response_size"]
    w: TensorType["response_size"]
    inf_mask: TensorType["response_size", "vocab_size"]
    t_rewards: TensorType["response_size"]
    ent_rewards: TensorType["response_size"]


@dataclass
class PPORLBatch:
    """
    A batched version of PPORLElement. See PPORLElement for details on the
    individual fields; every field here carries an extra leading batch
    dimension.
    """

    query_tensors: TensorType["batch_size", "query_size"]
    response_tensors: TensorType["batch_size", "response_size"]
    lens: TensorType["batch_size"]
    s_lens: TensorType["batch_size"]
    mask: TensorType["batch_size", "response_size"]
    logprobs: TensorType["batch_size", "response_size"]
    rewards: TensorType["batch_size", "response_size"]
    rev_kl: TensorType["batch_size", "response_size"]
    w: TensorType["batch_size", "response_size"]
    inf_mask: TensorType["batch_size", "response_size", "vocab_size"]
    t_rewards: TensorType["batch_size", "response_size"]
    ent_rewards: TensorType["batch_size", "response_size"]

================================================
FILE: minillm/losses.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from torchtyping import TensorType

from .data_types import PPORLBatch
from .utils import whiten, get_entropy, get_x_entropy, get_log_probs

from transformers import mpu

from utils import all_gather, print_rank


class Loss():
    """Loss computation for MiniLLM PPO training.

    Provides the clipped policy-gradient (PPO) loss on rollout batches and
    the auxiliary language-modeling / knowledge-distillation loss on
    pre-training batches.
    """

    def __init__(self, args, trainer):
        self.args = args
        self.trainer = trainer

    def _get_cumsum_rewards(self, rewards):
        """Return the discounted return from t=0 of per-token rewards.

        :param rewards: (batch_size, response_size) float tensor.
        :return: (batch_size,) tensor of gamma-discounted sums.
        """
        full_rewards = torch.zeros_like(rewards[:, 0])
        # Accumulate backwards so each earlier step picks up one more factor of gamma.
        for t in reversed(range(rewards.size(1))):
            full_rewards = self.args.gamma * full_rewards + rewards[:, t]

        return full_rewards

    def _get_advantages_and_returns(
        self,
        rewards: TensorType["batch_size", "response_size"],
        response_length: int,
        mask: TensorType["batch_size", "response_size"],
        use_whitening: Optional[bool] = True,
    ) -> torch.Tensor:
        """Length-normalized, optionally whitened reward-to-go advantages.

        FIX: the previous annotation promised a ``Tuple[Tensor, Tensor]``,
        but the method has always returned a single detached tensor.

        :param rewards: Per-token rewards, (batch_size, response_size).
        :param response_length: Number of response positions to process.
        :param mask: Valid-token mask, (batch_size, response_size).
        :param use_whitening: Whiten (zero-mean, unit-variance) the advantages.
        :return: Detached advantages, (batch_size, response_size).
        """
        last_rw = 0
        rw_reversed = []

        rewards = rewards.float()
        mask = mask.float()
        # lens[:, t] = number of valid (masked-in) tokens from position t to
        # the end, computed via cumulative sums instead of a Python loop.
        lens = torch.cumsum(mask, dim=-1)
        lens = mask - lens + lens[:, -1:]
        lens = torch.masked_fill(lens, lens == 0, 1)  # guard against division by zero

        # Reward-to-go: discounted suffix sum of rewards.
        for t in reversed(range(response_length)):
            rw_delta = rewards[:, t]
            last_rw = rw_delta + self.args.gamma * last_rw
            rw_reversed.append(last_rw)

        rw = torch.stack(rw_reversed[::-1], dim=1)
        rw = rw / lens

        advantages = rw

        if use_whitening:
            advantages = whiten(advantages)

        return advantages.detach()

    def _pg_loss(
        self,
        logprobs: TensorType["batch_size", "response_size"],
        old_logprobs: TensorType["batch_size", "response_size"],
        advantages: TensorType["batch_size", "response_size"],
        mask: TensorType["batch_size", "response_size"],
        w: TensorType["batch_size", "response_size"],
    ):
        """Clipped PPO objective with importance-sampling weights ``w``.

        References:
        - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
        """
        n = mask.sum()

        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio.float())
        ratio = ratio * w

        # Diagnostics only: flag numerically broken inputs.
        # (tensor .any() replaces Python any(...) over every element.)
        if torch.isinf(advantages).any():
            print("[ERROR] advantage inf")

        if torch.isinf(ratio).any():
            print("[ERROR] ratio inf")

        if torch.isnan(advantages).any():
            print("[ERROR] advantage nan")

        if torch.isnan(ratio).any():
            print("[ERROR] ratio nan")

        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio,
            1.0 - self.args.cliprange,
            1.0 + self.args.cliprange,
        )
        # Pessimistic (elementwise max) of the unclipped/clipped losses,
        # averaged over valid tokens.
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2).float() * mask) / n

        return pg_loss

    def _reg_loss(self, query_ids, response_ids, mask, logits, inf_mask, stats):
        """Single-step entropy regularization: E[x-entropy(student, teacher) - entropy(student)]."""
        with torch.no_grad():
            t_logits = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask, base="teacher", return_logprobs=False)

        xent = get_x_entropy(logits, t_logits, inf_mask, mask, model_parallel=self.args.model_parallel)
        s_ent = get_entropy(logits, inf_mask, mask, model_parallel=self.args.model_parallel)
        loss_exp_ent = torch.sum((xent - s_ent) * mask) / mask.sum()
        stats["reg_loss"] = loss_exp_ent.item()

        return loss_exp_ent

    def get_input_batch(self, ppo_batch: PPORLBatch, pt_batch):
        """Concatenate a PPO rollout batch with a pre-training batch along dim 0."""
        query_tensors = ppo_batch.query_tensors
        response_tensors = ppo_batch.response_tensors
        ppo_input_batch = self.trainer.get_model_inputs(query_tensors, response_tensors)
        pt_input_batch, _ = pt_batch
        # Both batches must expose the same model-input keys before merging.
        assert len(ppo_input_batch) == len(pt_input_batch), list(ppo_input_batch.keys())
        input_batch = {}
        for k in ppo_input_batch:
            input_batch[k] = torch.cat([ppo_input_batch[k], pt_input_batch[k]], dim=0)
        return input_batch

    def ppo_loss(self, batch: PPORLBatch, logits):
        """Compute the PPO loss (plus optional single-step regularizer) and stats.

        :param batch: Rollout batch.
        :param logits: Policy logits over the full (query + response) sequence.
        :return: (loss, stats dict).
        """
        stats = {}
        query_tensors = batch.query_tensors
        response_tensors = batch.response_tensors
        lens = batch.lens
        s_lens = batch.s_lens
        mask = batch.mask
        old_logprobs = batch.logprobs
        old_rewards = batch.rewards
        rev_kl = batch.rev_kl
        w = batch.w
        inf_mask = batch.inf_mask

        response_length = response_tensors.shape[-1]

        start = query_tensors.size(1) - 1 # "-1" for the first generated token AS TARGET
        end = query_tensors.size(1) + response_tensors.size(1) - 1 # "remove the last token that does not have target"

        logits = logits / self.args.temperature
        logits = logits[:, start:end]
        if inf_mask is not None:
            logits = logits.masked_fill(inf_mask, -float("inf"))

        tokens = torch.cat((query_tensors, response_tensors), dim=1)[
            :, -self.trainer.max_length :
        ]
        mask = self.trainer.get_mask(tokens)[:, start:end]

        logprobs = get_log_probs(logits, response_tensors, mask, inf_mask, model_parallel=self.args.model_parallel)

        advantages = self._get_advantages_and_returns(
            old_rewards, response_length, mask
        )

        loss = self._pg_loss(
            logprobs=logprobs,
            old_logprobs=old_logprobs,
            advantages=advantages,
            mask=mask,
            w=w,
        )
        stats["pg_loss"] = loss.item()

        single_step_reg_loss = self._reg_loss(query_tensors, response_tensors, mask, logits, inf_mask, stats)
        stats["reg_loss"] = single_step_reg_loss.item()

        if self.args.single_step_reg:
            loss += single_step_reg_loss

        stats["rl_loss"] = loss.item()

        with torch.no_grad():
            # Logging-only values, averaged across the data-parallel group.
            cumsum_rewards = self._get_cumsum_rewards(old_rewards)
            rev_kl = torch.sum(rev_kl, dim=-1)

            if self.args.length_norm:
                cumsum_rewards = cumsum_rewards / lens
                rev_kl = rev_kl / s_lens

            cumsum_rewards = all_gather(cumsum_rewards, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).mean(dim=0).item()
            rev_kl = all_gather(rev_kl, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).mean(dim=0).item()
            lens = all_gather(lens, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).float().mean(dim=0).item()
            s_lens = all_gather(s_lens, dim=0, world_size=self.trainer.dp_world_size, group=self.trainer.dp_group).float().mean(dim=0).item()

        stats["reward"] = cumsum_rewards
        stats["rev_kl"] = rev_kl
        stats["mixed_lens"] = lens
        stats["stu_lens"] = s_lens

        return loss, stats

    def pt_loss(self, batch, logits):
        """LM cross-entropy loss, optionally mixed with a word-level KD loss.

        :param batch: (model_batch, no_model_batch) pair from the LM pipeline.
        :param logits: Student logits for the LM batch.
        :return: (loss, stats dict). When no teacher model / kd_ratio is
            configured, the loss is the plain LM loss.
        """
        stats = {}
        model_batch, no_model_batch = batch
        loss_mask = (no_model_batch["label"] != -100).int()
        if self.args.model_parallel:
            lm_losses = mpu.parallel_cross_entropy(logits.contiguous().float(), no_model_batch["label"]).view(-1)
            lm_loss = (lm_losses * loss_mask.view(-1)).sum(-1) / loss_mask.view(-1).sum(-1)
        else:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            lm_loss = loss_fn(logits.view(-1, logits.size(-1)), no_model_batch["label"].view(-1))

        # FIX: `loss` was previously assigned only inside the distillation
        # branch below, so pt_loss raised UnboundLocalError whenever the
        # teacher model or kd_ratio was unset.
        loss = lm_loss
        distil_loss = 0
        if self.trainer.teacher_model is not None and self.args.kd_ratio is not None:
            with torch.no_grad():
                teacher_outputs = self.trainer.teacher_model(**model_batch, return_dict=True, use_cache=False)
                teacher_logits = teacher_outputs.logits
            if self.args.model_parallel:
                distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
                distil_losses = distil_losses.view(-1)
                distil_loss = (distil_losses * loss_mask.view(-1)).sum(-1) / loss_mask.view(-1).sum(-1)
            else:
                # Soft cross-entropy: -sum_v p_teacher(v) * log p_student(v),
                # averaged over valid (non-ignored) positions.
                teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
                inf_mask = torch.isinf(logits)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
                x = torch.sum(prod_probs, dim=-1).view(-1)
                distil_loss = -torch.sum(x * loss_mask.view(-1), dim=0) / torch.sum(loss_mask.view(-1), dim=0)

            loss = (1 - self.args.kd_ratio) * lm_loss + self.args.kd_ratio * distil_loss

        stats["pt_loss"] = loss.item()
        stats["lm_loss"] = lm_loss.item()
        # FIX: distil_loss is the Python int 0 when distillation is disabled,
        # which has no .item().
        stats["ds_loss"] = distil_loss.item() if torch.is_tensor(distil_loss) else float(distil_loss)

        return loss, stats

================================================
FILE: minillm/model.py
================================================
import torch.nn as nn
from transformers import (
    AutoConfig,)

from utils import get_model


class PPOModel(nn.Module):
    """Thin nn.Module wrapper around the student (policy) model.

    Loads the model via the project's ``get_model`` helper and keeps it in
    eval mode so that dropout stays disabled during RL.
    """

    def __init__(self, args, device):
        super().__init__()
        self.model_parallel = args.model_parallel
        self.config = AutoConfig.from_pretrained(args.model_path)
        self.base_model = get_model(args, device)
        self.base_model.eval() # no dropout for RL

    def forward(self, **x):
        # Pure pass-through to the wrapped model.
        base_model_outputs = self.base_model(**x)
        return base_model_outputs
    
    def generate(self, **x):
        # Delegate generation to the wrapped model.
        return self.base_model.generate(**x)
    
    def set_force_gradient_checkpointing(self, value):
        # Forward the gradient-checkpointing toggle to the wrapped model.
        self.base_model.set_force_gradient_checkpointing(value)


================================================
FILE: minillm/pipelines.py
================================================
import os
import json
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, DistributedSampler
from transformers import mpu
import torch.distributed as dist

from data_utils.distributed_indexed import DistributedMMapIndexedDataset
from torch.distributed import get_rank, get_world_size
from utils import print_rank


class PPOPipeline():
    """Prompt (and optional gold-response) dataset for PPO rollouts.

    Reads a distributed memory-mapped token dataset. Within a record, the
    token id 65535 acts as a separator between the prompt and the response
    (except for the "qwen" model type, where 65535 is a valid token).
    """

    def __init__(self, args, tokenizer, split, ppo_data_path=None, fix_prompts=False, num=-1):
        super().__init__()
        # FIX: tokenizer was assigned twice; keep a single assignment.
        self.args = args
        self.tokenizer = tokenizer
        self.split = split
        self.pad_id = self.tokenizer.eos_token_id
        self.max_length = args.max_length
        self.rng_ppo = random.Random(args.seed_ppo)
        self.min_prompt_length = args.min_prompt_length
        self.max_prompt_length = args.max_prompt_length

        self.ppo_ctx = DistributedMMapIndexedDataset(ppo_data_path, f"{split}", get_rank(), get_world_size())
        self.ppo_raw, self.ppo_answers = None, None
        # Optional raw jsonl next to the binary data provides reference answers.
        if os.path.exists(os.path.join(ppo_data_path, f"{split}.jsonl")):
            with open(os.path.join(ppo_data_path, f"{split}.jsonl")) as f:
                self.ppo_raw = [json.loads(line) for line in f.readlines()]
                self.ppo_answers = [x["output"] if isinstance(x["output"], list) else [x["output"]] for x in self.ppo_raw]

        self.num = min(num, len(self.ppo_ctx)) if num > 0 else len(self.ppo_ctx)
        self.fix_prompts = fix_prompts
        # FIX: was sized by the raw `num` argument, which is -1 when
        # "use all data" is requested; size by the resolved self.num instead.
        self.prompt_lengths = [None for _ in range(self.num)]
        print_rank(f"Num PPO instances: {len(self.ppo_ctx)}")

    def __len__(self):
        return self.num

    def __getitem__(self, index: int):
        """Return (prompt_ids, response_ids-or-None) for one record."""
        data = self.ppo_ctx[index].astype(int)

        assert len(data) <= self.max_prompt_length

        # 65535 separates prompt from response (not used for qwen tokenizers).
        if self.args.model_type!="qwen" and 65535 in data:
            source_len = np.where(data==65535)[0][0]
            prompt = data[:source_len]
            response = data[source_len+1:]
        else:
            prompt = data
            response = None

        return prompt, response

    def collate(self, samples):
        """Collate (prompt, response) pairs into left-padded model inputs.

        :param samples: list of (prompt_ids, response_ids-or-None) pairs.
        :return: (model_batch, no_model_batch) of padded long tensors.
        """
        bs = len(samples)

        max_prompt_length = self.max_prompt_length

        model_batch = {
            "input_ids": torch.ones(bs, max_prompt_length, dtype=torch.long) * self.pad_id,
            "attention_mask": torch.zeros(bs, max_prompt_length, dtype=torch.long),
        }

        no_model_batch = {
            "full_ids": torch.ones(bs, self.max_length, dtype=torch.long) * self.pad_id,
            "full_attention_mask": torch.zeros(bs, self.max_length, dtype=torch.long),
            "full_label_ids": torch.ones(bs, self.max_length, dtype=torch.long) * -100,
        }

        for i, (prompt, response) in enumerate(samples):
            # Prompts are left-padded so generation starts at a fixed column.
            model_batch["input_ids"][i][-len(prompt):] = torch.tensor(prompt, dtype=torch.long)
            model_batch["attention_mask"][i][-len(prompt):] = 1
            if response is not None:
                # Full sequence uses next-token targets: inputs drop the last
                # token, labels start at the first response token.
                full_ids = np.concatenate([prompt, response], axis=0)
                no_model_batch["full_ids"][i][:len(full_ids)-1] = torch.tensor(full_ids[:-1], dtype=torch.long)
                no_model_batch["full_attention_mask"][i][:len(full_ids)-1] = 1.0
                no_model_batch["full_label_ids"][i][len(prompt)-1:len(full_ids)-1] = torch.tensor(response, dtype=torch.long)

        return model_batch, no_model_batch

    def move_to_device(self, model_batch, no_model_batch, device):
        """Move both collated dicts to `device` in place; return them."""
        for k in model_batch:
            model_batch[k] = model_batch[k].to(device)
        for k in no_model_batch:
            no_model_batch[k] = no_model_batch[k].to(device)

        return model_batch, no_model_batch

    def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
        """Build a DataLoader sharded over the data-parallel group."""
        if self.args.model_parallel:
            dp_world_size = mpu.get_data_parallel_world_size()
            dp_rank = mpu.get_data_parallel_rank()
        else:
            dp_world_size = dist.get_world_size()
            dp_rank = dist.get_rank()

        sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last, rank=dp_rank, num_replicas=dp_world_size)
        return DataLoader(
            self, sampler=sampler, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers
        )


class LMPipeline():
    """Language-modeling dataset used for the auxiliary PT loss.

    Reads a distributed memory-mapped token dataset; the token id 65535 is
    treated as a prompt/response separator (dropped from the sequence), and
    label positions before it are masked out with -100.
    """

    def __init__(self, args, tokenizer, split, lm_data_path=None, num=-1):
        super().__init__()
        # FIX: tokenizer was assigned twice; keep a single assignment.
        self.args = args
        self.tokenizer = tokenizer
        self.split = split
        self.pad_id = self.tokenizer.eos_token_id
        self.max_length = args.max_length
        self.rng_lm = random.Random(args.seed_lm)

        self.lm_ctx = DistributedMMapIndexedDataset(lm_data_path, f"{split}", get_rank(), get_world_size())
        self.num = min(num, len(self.lm_ctx)) if num > 0 else len(self.lm_ctx)
        print_rank(f"Num LM instances: {len(self.lm_ctx)}")

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        return self._get_lm(index)

    def _get_lm(self, index):
        """Read one record and truncate it to max_length tokens."""
        data = self.lm_ctx[index]
        input_ids = data.astype(int)
        return {
            "input_ids": input_ids[:self.max_length]
        }

    def _process_lm(self, i, samp, model_data, no_model_data):
        """Fill row `i` of the collated tensors with next-token-shifted data."""
        input_ids = samp["input_ids"]
        source_len = 1

        # Drop the 65535 prompt/response separator; labels before it are masked.
        if self.args.model_type!="qwen" and 65535 in input_ids:
            source_len = np.where(input_ids==65535)[0][0]
            input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0)
        input_ids = input_ids[:self.max_length]
        input_len = len(input_ids)
        # Inputs drop the last token; labels are the inputs shifted left by one.
        model_data["input_ids"][i][:input_len-1] = torch.tensor(input_ids[:-1], dtype=torch.long)
        model_data["attention_mask"][i][:input_len-1] = 1.0
        if self.args.model_type in ["gpt2"]:
            model_data["position_ids"][i][:input_len-1] = torch.arange(0, input_len-1, dtype=torch.long)
        no_model_data["label"][i][:input_len-1] = torch.tensor(input_ids[1:], dtype=torch.long)
        no_model_data["label"][i][:source_len-1] = -100
        no_model_data["loss_mask"][i][:input_len-1] = 1.0
        no_model_data["loss_mask"][i][:source_len-1] = 0

    def move_to_device(self, model_batch, no_model_batch, device):
        """Move both collated dicts to `device` in place; return them."""
        for k in model_batch:
            model_batch[k] = model_batch[k].to(device)

        for k in no_model_batch:
            no_model_batch[k] = no_model_batch[k].to(device)

        return model_batch, no_model_batch

    def collate(self, samples):
        """Collate raw samples into padded model/label tensors.

        :param samples: list of dicts with an "input_ids" array each.
        :return: (model_data, no_model_data).
        """
        bs = len(samples)

        max_length = self.max_length

        model_data = {
            "input_ids": torch.ones(bs, max_length, dtype=torch.long) * self.pad_id,
            "attention_mask": torch.zeros(bs, max_length, dtype=torch.long)
        }

        if self.args.model_type in ["gpt2"]:
            model_data["position_ids"] = torch.zeros(bs, max_length, dtype=torch.long)

        # Consistency fix: use the local `max_length` for all tensors
        # (the "label" tensor previously used self.max_length — same value).
        no_model_data = {
            "label": torch.ones(bs, max_length, dtype=torch.long) * -100,
            "loss_mask": torch.zeros(bs, max_length)
        }

        for i, samp in enumerate(samples):
            self._process_lm(i, samp, model_data, no_model_data)

        return model_data, no_model_data

    def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
        """Build a DataLoader sharded over the data-parallel group."""
        if self.args.model_parallel:
            dp_world_size = mpu.get_data_parallel_world_size()
            dp_rank = mpu.get_data_parallel_rank()
        else:
            dp_world_size = dist.get_world_size()
            dp_rank = dist.get_rank()

        sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last, rank=dp_rank, num_replicas=dp_world_size)
        return DataLoader(
            self, sampler=sampler, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers
        )


================================================
FILE: minillm/reward.py
================================================
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    mpu)


class Reward():
    """Teacher-based reward: per generated token, the (mean-centered)
    logit of the chosen token minus the logsumexp over the vocabulary —
    i.e. the teacher's log-probability of each generated token.
    """

    def __init__(self, args, tokenizer: AutoTokenizer, model: AutoModelForCausalLM):
        self.args = args
        self.tokenizer = tokenizer
        self.model = model
        # Cached special-token ids; pad id builds attention masks below.
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

    def get_input_batch(self, input_ids, gen_ids, output_pos=True):
        """Build model inputs for the concatenated prompt + generation.

        Attention mask marks every non-pad position; for gpt2-style models,
        position ids restart from 0 at the first non-pad token.
        """
        full_ids = torch.cat([input_ids, gen_ids], dim=-1)
        attention_mask = (full_ids != self.pad_token_id)

        model_inputs = {
            "input_ids": full_ids,
            "attention_mask": attention_mask,
            "use_cache": False
        }
        
        if (self.args.model_type in ["gpt2"]) and output_pos:
            position_ids = torch.cumsum(attention_mask, dim=-1) - 1
            position_ids.masked_fill_(~attention_mask, 0)
            model_inputs["position_ids"] = position_ids
        
        return model_inputs

    def reward_fn(self, input_ids, gen_ids, inf_mask=None, output_pos=True):
        """Score each generated token with the teacher model.

        :param input_ids: Prompt token ids (left-padded; see PPOPipeline).
        :param gen_ids: Generated token ids to score.
        :param inf_mask: Passed through unchanged in the returned dict.
        :return: dict with "rewards" (same shape as gen_ids) and "inf_mask".
        """
        # not include eos token
        
        self.model.eval()
        # input_ids = input_ids.repeat(1, 1)
        
        model_inputs = self.get_input_batch(input_ids, gen_ids, output_pos=output_pos)

        with torch.no_grad():
            outputs = self.model(**model_inputs)
        
        logits = outputs.logits # (B, L, V)
        # Center logits per position (softmax-invariant; keeps values bounded).
        if self.args.model_parallel:
            logits = logits - mpu.parallel_mean(logits.float(), dim=-1).unsqueeze(-1)
        else:
            logits = logits - torch.mean(logits, dim=-1, keepdim=True)
        
        mask = model_inputs["attention_mask"]
        logits = logits * mask.unsqueeze(-1) # set logits output by padding to 0
        
        # Keep only positions that predict generated tokens
        # (position input_len-1 predicts the first generated token).
        logits = logits[:, input_ids.size(-1)-1:, :]
        mask = mask[:, input_ids.size(-1)-1:]

        # Logit of the token that was actually generated at each step.
        if self.args.model_parallel:
            selection_value = mpu.parallel_gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)
        else:
            selection_value = torch.gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)

        # Log-normalizer over the vocabulary at each step; padded steps zeroed.
        current_logits = logits[:, :-1, :]
        if self.args.model_parallel:
            next_state_value = mpu.parallel_logsumexp(current_logits.float(), dim=-1)
        else:
            next_state_value = torch.logsumexp(current_logits, dim=-1)
        next_state_value = next_state_value * mask[:, :-1]
        raw_next_state_value = next_state_value

        # score = logit(chosen) - logsumexp(all) = log softmax prob of chosen token.
        scores = selection_value - next_state_value
        
        assert all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1))))
        
        assert scores.size() == gen_ids.size()
        
        return {
            "rewards": scores,
            "inf_mask": inf_mask
        }


================================================
FILE: minillm/sampler.py
================================================
import torch
import os

from .data_types import PromptBatch, PPORLElement
from .pipelines import PPOPipeline
from .trainer import PPOTrainer

from utils import get_rank, print_rank, all_gather, save_rank
from .utils import get_rev_kl
from transformers import mpu

class PPOSampler():
    """
    Orchestrator that prepares data for PPO training.
    Transforms samples from `pipeline` into `PPORLElement`s and pushes them
    into the trainer's rollout store.
    """

    def __init__(
        self,
        args,
        trainer: PPOTrainer,
        pipeline: PPOPipeline,
        chunk_size: int = 512,
    ):
        self.args = args
        self.pipeline = pipeline
        self.trainer = trainer
        self.chunk_size = chunk_size

        self.pipeline_loader = self.pipeline.create_loader(
            self.chunk_size, shuffle=True, drop_last=True, num_workers=self.args.num_workers
        )
        self.pipeline_iterator = iter(self.pipeline_loader)

        # Register this sampler so the trainer can request fresh rollouts.
        self.trainer.set_sampler(self)

        self.epochs = 0

    def run_sample(self, num_rollouts_per_device: int = 1024, iter_count: int = 0):
        """
        Takes `num_rollouts_per_device` prompts from `pipeline`, samples the model
        and computes the KL against a reference model. It then appends
        PPORLElements to the trainer's rollout store.
        """
        ppo_rl_elements = []

        while len(ppo_rl_elements) < num_rollouts_per_device:
            # FIX: the progress message was gated by `((not mp) or rank) == 0`,
            # which compares the boolean OR itself to 0 and therefore never
            # printed in the non-model-parallel case.
            if (not self.args.model_parallel) or mpu.get_model_parallel_rank() == 0:
                print(f"Rank {get_rank()}: Number Sampling Elements {len(ppo_rl_elements)} / {num_rollouts_per_device}")
            try:
                batch: PromptBatch = next(self.pipeline_iterator)
            except StopIteration:
                # Exhausted the prompt set: start a new epoch with a reshuffled sampler.
                self.epochs += 1
                print_rank(f"Another outer ppo epoch, outer ppo epoch: {self.epochs}")
                save_rank(f"Another outer ppo epoch, outer ppo epoch: {self.epochs}", os.path.join(self.args.save, "log.txt"))

                self.pipeline_loader.sampler.set_epoch(self.epochs)
                self.pipeline_iterator = iter(self.pipeline_loader)
                batch = next(self.pipeline_iterator)

            batch, no_model_batch = batch
            n = batch["input_ids"].size(0)

            batch, no_model_batch = self.pipeline.move_to_device(batch, no_model_batch, self.trainer.device)

            query_ids = batch["input_ids"]

            # Generate (possibly teacher-mixed) rollouts and score them.
            with torch.no_grad():
                mode = "base"
                gen_out = self.trainer.generate(**batch, return_dict_in_generate=True, mode=mode, teacher_mixed_sample=(self.args.teacher_mixed_alpha is not None), output_scores=True)
                full_ids = gen_out.sequences
                response_ids = full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
                mask = (full_ids != self.trainer.tokenizer.pad_token_id)[:, query_ids.size(-1)-1:query_ids.size(-1)+response_ids.size(-1)-1]
                lens = torch.sum(mask, dim=-1)
                gen_logits = gen_out.scores # NOTE: [b, s, h_p]
                inf_mask = torch.isinf(gen_logits)
                scores = self.trainer.reward_fn(query_ids, response_ids, inf_mask=inf_mask)
                t_rewards = scores["rewards"]
                inf_mask = scores["inf_mask"]
                _, rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base=mode)

                # Student-only generation features (needed when rollouts were
                # sampled from the teacher-mixed distribution).
                if self.args.teacher_mixed_alpha is not None:
                    s_gen_out = self.trainer.generate(**batch, return_dict_in_generate=True, mode=mode, output_scores=True)
                    s_full_ids = s_gen_out.sequences
                    s_response_ids = s_full_ids[:, query_ids.size(1):]
                    s_inf_mask = torch.isinf(s_gen_out.scores)
                    s_response_ids = s_full_ids[:, query_ids.size(1):] # remove prompt (may include start token)
                    s_scores = self.trainer.reward_fn(query_ids, s_response_ids, inf_mask=s_inf_mask)
                    s_t_rewards = s_scores["rewards"]
                    s_inf_mask = s_scores["inf_mask"]
                    _, s_rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, s_response_ids, inf_mask=s_inf_mask, base=mode)
                    s_mask = (s_full_ids != self.trainer.tokenizer.pad_token_id)[:, query_ids.size(-1)-1:query_ids.size(-1)+s_response_ids.size(-1)-1]
                    s_lens = torch.sum(s_mask, dim=-1)
                else:
                    s_t_rewards = t_rewards
                    s_rollout_logprobs = rollout_logprobs
                    s_mask = mask
                    s_lens = lens

            rev_kl = get_rev_kl(s_t_rewards, s_rollout_logprobs, s_mask)

            if self.args.teacher_mixed_alpha is not None:
                with torch.no_grad():
                    _, t_rollout_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base="teacher") # recompute because of the fp16 loss

            # Compute logprobs and the importance-sampling weight w.
            with torch.no_grad():
                if self.args.teacher_mixed_alpha is not None:
                    _, raw_logprobs = self.trainer.compute_logits_and_log_probs(query_ids, response_ids, inf_mask=inf_mask, base="base") # raw_logprobs: compute using the new model
                    logprobs = raw_logprobs
                    # Behavior policy is the alpha-mixture of student and teacher.
                    mix_probs = (1 - self.args.teacher_mixed_alpha) * torch.exp(rollout_logprobs.float()) + self.args.teacher_mixed_alpha * torch.exp(t_rollout_logprobs.float())
                    mix_logprobs = torch.log(mix_probs)
                    log_w = logprobs - mix_logprobs
                    w = torch.exp(log_w) # importance sampling weight
                else:
                    raw_logprobs = rollout_logprobs
                    logprobs = rollout_logprobs
                    w = torch.ones_like(logprobs)

                # Entropy bonus: negative log-prob of the sampled tokens.
                ent_rewards = -logprobs

            rewards = t_rewards + ent_rewards

            if self.args.reward_scaling is not None:
                rewards = rewards / self.args.reward_scaling

            clip_reward = self.args.cliprange_reward
            if clip_reward:
                rewards = torch.clip(rewards, -clip_reward, clip_reward)

            # Move everything to CPU before storing in the rollout buffer.
            query_ids = query_ids.cpu()
            response_ids = response_ids.cpu()
            lens = lens.cpu()
            s_lens = s_lens.cpu()
            mask = mask.cpu()
            logprobs = logprobs.cpu()
            rewards = rewards.cpu()
            rev_kl = rev_kl.cpu()
            w = w.cpu()
            inf_mask = inf_mask.cpu()

            new_ppo_rl_elements = [
                PPORLElement(
                    query_tensor=query_ids[i],
                    response_tensor=response_ids[i],
                    lens=lens[i],
                    s_lens=s_lens[i],
                    mask=mask[i],
                    logprobs=logprobs[i],
                    rewards=rewards[i],
                    rev_kl=rev_kl[i],
                    w=w[i],
                    inf_mask=inf_mask[i],
                    t_rewards=t_rewards[i],
                    ent_rewards=ent_rewards[i]
                )
                for i in range(n)
            ]
            ppo_rl_elements.extend(new_ppo_rl_elements)

        ppo_rl_elements = ppo_rl_elements[:num_rollouts_per_device]
        # Push samples and rewards to trainer's rollout storage
        self.trainer.push_to_store(ppo_rl_elements)

        if self.args.save_rollout:
            all_query_ids = all_gather(torch.stack([e.query_tensor for e in ppo_rl_elements], dim=0).to(self.trainer.device))
            all_response_ids = all_gather(torch.stack([e.response_tensor for e in ppo_rl_elements], dim=0).to(self.trainer.device))
            # FIX: PPORLElement has no `entropy` field (AttributeError here);
            # the per-token entropy rewards are stored as `ent_rewards`.
            all_entropy = all_gather(torch.stack([e.ent_rewards for e in ppo_rl_elements], dim=0).to(self.trainer.device))
            rollout_save_path = os.path.join(self.args.save, "rollout_history", str(iter_count))
            if get_rank() == 0:
                os.makedirs(rollout_save_path, exist_ok=True)
                torch.save((all_query_ids, all_response_ids, all_entropy), os.path.join(rollout_save_path, "all.pt"))


================================================
FILE: minillm/storages.py
================================================
import json
import os
import time
from abc import abstractmethod
from typing import Any, Callable, Iterable

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

from .data_types import PPORLElement, PPORLBatch

from utils import get_rank


class BaseRolloutStore(Dataset):
    """Abstract base for rollout storages: a torch ``Dataset`` over collected experiences."""

    def __init__(self, capacity=-1):
        # `history` is populated by subclasses via push(); `capacity` is advisory.
        self.history: Iterable[Any] = None
        self.capacity = capacity

    @abstractmethod
    def push(self, exps: Iterable[Any]):
        """Push experiences to rollout storage."""
        ...

    def __getitem__(self, index: int) -> PPORLElement:
        return self.history[index]

    def __len__(self) -> int:
        return len(self.history)

    @abstractmethod
    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Callable = None,
        num_workers: int = 0,
        drop_last: bool = False
    ) -> DataLoader:
        """
        Build a DataLoader over the stored rollouts.

        :param prep_fn: Applied to RLElement after collation (typically tokenizer)
        :type prep_fn: Callable
        """
        ...

    @abstractmethod
    def broadcast(self, batch, src=0, group=None):
        """Broadcast a collated batch from rank `src` across the process group."""
        ...

    @abstractmethod
    def move_to_device(self, batch, device):
        """Move every tensor of a collated batch onto `device`."""
        ...


class PPORolloutStorage(BaseRolloutStore):
    """
    Rollout storage for training PPO.

    Holds a per-device list of `PPORLElement` experiences and collates them into
    a `PPORLBatch` of padded tensors for the training dataloader.
    """

    def __init__(self, pad_token_id, seed):
        super().__init__()

        self.pad_token_id = pad_token_id
        # NOTE(review): history starts as [None]; callers (e.g. the trainer) are
        # expected to call clear_history() before pushing real rollouts.
        self.history: Iterable[PPORLElement] = [None]
        # Dedicated RNG so dataloader shuffling is reproducible per device.
        self.rng = torch.Generator()
        self.rng.manual_seed(seed)

    def push(self, exps: Iterable[PPORLElement]):
        """Append a batch of experiences to the rollout history."""
        self.history += exps

    def save(self, path):
        """Persist this rank's rollout history to `<path>/history_<rank>.pkl`."""
        def exp_to_dict(exp):
            return {k: v for k, v in exp.__dict__.items()}

        data = [exp_to_dict(exp) for exp in self.history]

        # Fixed: write the same filename that load() reads. Previously this saved
        # to "<rank>.pkl" while load() read "history_<rank>.pkl", so a save
        # followed by a load could never round-trip.
        torch.save(data, os.path.join(path, f"history_{get_rank()}.pkl"))

    def load(self, path):
        """Restore this rank's rollout history previously written by save()."""
        data = torch.load(os.path.join(path, f"history_{get_rank()}.pkl"), map_location="cpu")
        self.history = [PPORLElement(**d) for d in data]

    def clear_history(self):
        """Drop all stored experiences."""
        self.history = []

    def export_history(self, location: str):
        """Dump the history as JSON (tensor fields converted to lists) for inspection."""
        assert os.path.exists(location)

        fpath = os.path.join(location, f"epoch-{str(time.time())}.json")

        def exp_to_dict(exp):
            return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}

        data = [exp_to_dict(exp) for exp in self.history]
        with open(fpath, "w") as f:
            f.write(json.dumps(data, indent=2))

    def __getitem__(self, index: int) -> PPORLElement:
        return self.history[index]

    def __len__(self) -> int:
        return len(self.history)

    def collate(self, elems: Iterable[PPORLElement]):
        """Collate a list of PPORLElement into one PPORLBatch of padded tensors."""
        # Debug aid: surface batches that still contain the [None] placeholder.
        if any([e is None for e in elems]):
            print(elems)
        return PPORLBatch(
            # Left padding of already left-padded queries: flip, right-pad, flip back.
            pad_sequence(
                [elem.query_tensor.flip(0) for elem in elems],
                padding_value=self.pad_token_id,
                batch_first=True,
            ).flip(1),
            # Right pad the rest, to have a single horizontal query/response split
            pad_sequence(
                [elem.response_tensor for elem in elems],
                padding_value=self.pad_token_id,
                batch_first=True,
            ),
            torch.tensor([elem.lens for elem in elems], dtype=torch.long),
            torch.tensor([elem.s_lens for elem in elems], dtype=torch.long),
            pad_sequence(
                [elem.mask for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),            
            pad_sequence(
                [elem.logprobs for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.rewards for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.rev_kl for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.w for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.inf_mask for elem in elems],
                padding_value=0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.t_rewards for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
            pad_sequence(
                [elem.ent_rewards for elem in elems],
                padding_value=0.0,
                batch_first=True,
            ),
        )

    def create_loader(self, batch_size: int, shuffle=False, drop_last: bool = False, num_workers: int = 0) -> DataLoader:
        """Build a DataLoader over this storage using our collate fn and seeded RNG."""
        # sampler = DistributedSampler(self, shuffle=shuffle, drop_last=drop_last)
        # we don't use distributed sampler because the dataset on each device is different
        return DataLoader(
            self, batch_size=batch_size, collate_fn=self.collate, num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, generator=self.rng
        )
        
    def broadcast(self, batch: PPORLBatch, src=0, group=None):
        """Broadcast every tensor field of `batch` from rank `src`."""
        for k, v in batch.__dict__.items():
            dist.broadcast(batch.__dict__[k], src=src, group=group)
            
    def move_to_device(self, batch: PPORLBatch, device):
        """Move every tensor field of `batch` onto `device` in place."""
        for k, v in batch.__dict__.items():
            batch.__dict__[k] = batch.__dict__[k].to(device)

================================================
FILE: minillm/trainer.py
================================================
import json
import os
import deepspeed
from time import time
from typing import Optional, Tuple
from collections import defaultdict

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import AdamW
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    GenerationConfig,
    mpu)

from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

from .utils import (
    get_scheduler_class,
    get_log_probs,
    get_rev_kl,
    significant
)

from .model import (
    PPOModel
)

from .pipelines import PPOPipeline, LMPipeline


from .storages import PPORolloutStorage
from .losses import Loss

from utils import print_rank, save_rank, get_rank, all_gather, save_parallel
from rouge_metric import compute_metrics


class PPOTrainer():
    """
    PPO-style distillation trainer (MiniLLM).

    NOTE(review): the original docstring said "`accelerate` based backend", but
    the visible code initializes the engine through DeepSpeed (`setup_ds`).
    """

    def __init__(self, args, tokenizer: AutoTokenizer, reward_fn, ds_config):
        # reward_fn(input_ids, gen_ids) -> dict with a "rewards" entry
        # (used as teacher log-probs in evaluate_ppo).
        self.args = args
        self.max_length = args.max_length
        self.ds_config = ds_config
        self.reward_fn = reward_fn
        self.device = torch.cuda.current_device()

        # Sync all ranks before building the model in the multi-GPU case.
        if int(os.environ.get("WORLD_SIZE", 1)) > 1:
            dist.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])

        if args.model_parallel:
            raise NotImplementedError
        else:
            # Pure data parallelism: world == data-parallel group.
            self.dp_world_size = dist.get_world_size()
            self.dp_rank = dist.get_rank()
            self.dp_group = None

        self.model = PPOModel(args, self.device)
        if args.model_parallel:
            raise NotImplementedError
        else:
            if dist.get_rank() == 0:
                print(' > number of parameters: {}M'.format(
                    int(sum([p.nelement() for p in self.model.parameters()]) / 1e6)), flush=True)

        # Sampler and teacher are injected later via set_sampler / set_teacher_model.
        self.sampler = None
        self.teacher_model = None
        self.opt = self.setup_optimizer()
        self.scheduler = self.setup_scheduler()
        # Wrap model/optimizer/scheduler in a DeepSpeed engine.
        self.model, self.opt, self.scheduler = self.setup_ds(self.model, self.opt, self.scheduler)
        
        self.tokenizer = tokenizer
        # Per-rank rollout storage, seeded so shuffling differs across ranks.
        self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.args.seed_ppo + self.dp_rank)
        self.store.clear_history()
        
        self.losses = Loss(args, self)
        # Defaults forwarded to every generate() call; callers may override.
        self.generate_kwargs = dict(
            do_sample=args.do_sample,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            max_length=args.max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

    def set_teacher_model(self, model):
        """Attach the (frozen) teacher model used for mixed sampling and KD signals."""
        self.teacher_model = model

    def set_sampler(self, sampler):
        """Attach the rollout sampler that refills the store between epochs."""
        self.sampler = sampler

    def setup_optimizer(self):
        """
        Returns an optimizer derived from an instance's TRLConfig
        """
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            betas=[0.9, 0.95],
            eps=1.0e-8,
            weight_decay=1.0e-6
        )

        return optimizer

    def setup_scheduler(self):
        """
        Returns a learning rate scheduler derived from an instance's TRLConfig
        """
        if self.args.scheduler_name == "constant_trm":
            scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=self.args.warmup_iters)
        elif self.args.scheduler_name == "cosine_trm":
            scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=self.args.warmup_iters, num_training_steps=self.args.total_iters)
        else:
            # Fallback: resolve a torch scheduler class by name (see minillm.utils).
            scheduler_class = get_scheduler_class(self.args.scheduler_name)
            scheduler = scheduler_class(self.opt, eta_min=self.args.lr_min, T_max=self.args.total_iters)
        
        return scheduler

    def setup_ds(self, model, optimizer=None, scheduler=None):
        """Initialize the DeepSpeed engine around model/optimizer/scheduler."""
        model, optimizer, _, scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=self.args,
            lr_scheduler=scheduler,
            mpu=mpu if self.args.model_parallel else None,
            config_params=self.ds_config
        )
        return model, optimizer, scheduler

    def add_eval_pipeline(self, eval_pipeline: PPOPipeline):
        """Adds pipeline from with validation prompts"""
        self.eval_pipeline = eval_pipeline

    def add_lm_pipeline(self, lm_pipeline: LMPipeline, eval_lm_pipeline: LMPipeline):
        """Attach the pre-training LM pipelines (train + eval) for the pt_loss term."""
        self.lm_pipeline = lm_pipeline
        self.eval_lm_pipeline = eval_lm_pipeline

    def get_model_inputs(
        self,
        query_tensors,
        response_tensors,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Concatenate query+response, truncate from the left to max_length,
        and build attention mask (plus position_ids for gpt2 left-padding)."""
        tokens = torch.cat((query_tensors, response_tensors), dim=1)[
            :, -self.max_length :
        ]
        attention_mask = self.get_mask(tokens)
  
        batch = {
            "input_ids": tokens,
            "attention_mask": attention_mask
        }
        
        if self.args.model_type in ["gpt2"]:  
            # For a proper positional encoding in case of left padding
            position_ids = attention_mask.cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask.eq(0), 0)
            batch["position_ids"] = position_ids
        
        return batch

    def get_mask(self, tokens):
        """Attention mask: 1 wherever the token is not the pad token."""
        attention_mask = (
            tokens.not_equal(self.tokenizer.pad_token_id).long()
        )
        return attention_mask

    def forward_model(self, batch):
        """Single forward pass through the (DeepSpeed-wrapped) student model."""
        outputs = self.model(
            **batch,
            return_dict=True,
            use_cache=False,
        )
        return outputs

    def compute_logits_and_log_probs(self, query_ids, response_ids, inf_mask=None, base="base", return_logprobs=True):
        """Run `base` (student) or `teacher` on query+response and return the
        temperature-scaled logits aligned with the response tokens, plus
        (optionally) the per-token log-probs of the response under that model."""
        batch = self.get_model_inputs(
            query_ids, response_ids
        )
        
        if base == "base":
            model_cls = self.model.module.forward
        elif base == "teacher":
            model_cls = self.teacher_model
        else:
            raise NotImplementedError

        outputs = model_cls(
            **batch,
            return_dict=True,
            use_cache=False
        )

        logits = outputs.logits
        logits = logits / self.args.temperature

        # Logit at position t predicts token t+1, so shift by one to align
        # logits with the response tokens.
        start = query_ids.size(1) - 1
        end = query_ids.size(1) + response_ids.size(1) - 1
        logits = logits[:, start:end]

        if inf_mask is not None:
            # Re-apply the -inf mask recorded at generation time (e.g. top-k/top-p filtering).
            logits = logits.masked_fill(inf_mask, -float("inf"))

        mask = batch["attention_mask"][:, start:end]
                
        if return_logprobs:
            logprobs = get_log_probs(logits, response_ids, mask, inf_mask, model_parallel=self.args.model_parallel)
            return logits, logprobs

        return logits

    def train(self):
        """
        Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
        """

        self.prepare_learning()
        self.iter_count = 1
        self.global_iter_count = 1
        self.nth_evaluation = 0

        # Baseline evaluation before any update.
        self.evaluate()

        print_rank("Total Steps:", self.total_steps, "Data Epochs:", self.args.epochs)
        lm_epochs = 0        
        logging_stats = defaultdict(float)

        for training_epoch in range(self.args.training_epochs):
            for ppo_epoch in range(self.n_updates_per_batch):
                for it, batch in enumerate(self.train_dataloader):
                    # Pull the next pre-training LM batch, cycling the dataloader
                    # (with a new sampler epoch) when it is exhausted.
                    # NOTE(review): if lm_pipeline were None, lm_batch would be
                    # undefined below — this code path assumes lm_pipeline is set.
                    if self.lm_pipeline is not None:
                        try:
                            lm_batch = next(self.lm_iterator)
                        except StopIteration:
                            lm_epochs += 1
                            print_rank(f"Another lm epoch, lm epochs: {lm_epochs}")
                            save_rank(f"Another lm epoch, lm epochs: {lm_epochs}", os.path.join(self.args.save, "log.txt"))
                            self.lm_dataloader.sampler.set_epoch(lm_epochs)
                            self.lm_iterator = iter(self.lm_dataloader)
                            lm_batch = next(self.lm_iterator)

                    self.store.move_to_device(batch, self.device)
                    self.lm_pipeline.move_to_device(*lm_batch, self.device)
                    stats = {}

                    if self.args.model_parallel:
                        raise NotImplementedError

                    # Toggle gradient checkpointing for the training forward only;
                    # the attribute lives either on the module or on its base_model.
                    if self.args.gradient_checkpointing:
                        try: self.model.module.set_force_gradient_checkpointing(True)
                        except: self.model.module.base_model.set_force_gradient_checkpointing(True)
                    
                    # One fused forward over [ppo rows; lm rows], then split the logits.
                    input_batch = self.losses.get_input_batch(batch, lm_batch)
                    logits = self.forward_model(input_batch).logits
                    ppo_logits = logits[:batch.query_tensors.size(0)]
                    lm_logits = logits[batch.query_tensors.size(0):]

                    # forward
                    forward_time = time()
                    # compute rl-related loss on explored data
                    rl_loss, rl_loss_stats = self.losses.ppo_loss(batch, ppo_logits)
                    stats.update(rl_loss_stats)
                    # compute lm-related loss on pre-training data
                    pt_loss, pt_loss_stats = self.losses.pt_loss(lm_batch, lm_logits)
                    stats.update(pt_loss_stats)
                    
                    loss = rl_loss + self.args.lm_coef * pt_loss
                    stats["tot_loss"] = loss.item()

                    forward_time = time() - forward_time
                    
                    # backward
                    backward_time = time()
                    self.model.backward(loss)
                    backward_time = time() - backward_time

                    # step
                    step_time = time()
                    self.model.step()
                    step_time = time() - step_time

                    if self.args.gradient_checkpointing:
                        try: self.model.module.set_force_gradient_checkpointing(False)
                        except: self.model.module.base_model.set_force_gradient_checkpointing(False)

                    # Save on a denser schedule early on (every 1000 global steps
                    # before step 10000), then every save_interval.
                    if self.iter_count % self.args.gradient_accumulation_steps == 0 and \
                        ((self.global_iter_count < 10000 and (self.global_iter_count % 1000 == 0)) or \
                        self.global_iter_count % self.args.save_interval == 0):
                        self.save()

                    # eval
                    if self.iter_count % self.args.gradient_accumulation_steps == 0 and \
                        ((self.global_iter_count < 1000 and (self.global_iter_count % 100 == 0)) or \
                        (self.global_iter_count % self.args.eval_interval == 0)):
                        self.evaluate()

                    elapsed_time = forward_time + backward_time + step_time
                    
                    stats["elapsed_time"] = elapsed_time
                    
                    for k in stats:
                        logging_stats[k] += stats[k]

                    # Logging
                    def get_log(log_stats, one_step_time):
                        keys = ["tot_loss", "rl_loss", "pt_loss", "pg_loss", "reg_loss", "reward", "rev_kl", "stu_lens", "mixed_lens"]
                        prefix = "train | data_epochs {:2d}/{:2d} | inner iter: {:3d}/{:3d} | ppo epoch: {:2d}/{:2d} | global iter: {:6d}/{:6d}".format(
                            self.sampler.epochs,
                            self.args.epochs,
                            it,
                            len(self.train_dataloader),
                            ppo_epoch,
                            self.n_updates_per_batch,
                            self.global_iter_count,
                            self.total_steps
                        )
                        suffix = "| lr: {:.4e} | scale: {:6.2f} | time: {:.3f} | step time: {:.3f}".format(
                            self.scheduler.get_last_lr()[0],
                            self.opt.cur_scale if hasattr(self.opt, "cur_scale") else 0,
                            elapsed_time,
                            one_step_time
                        )
                        for key in keys:
                            prefix += "| {}: {:.4f} ".format(key, log_stats.get(key, 0))
                        return prefix + suffix

                    mid_log_step = self.args.gradient_accumulation_steps // self.args.mid_log_num
                    mid_log_step = 1 if mid_log_step == 0 else mid_log_step
                    if self.iter_count % mid_log_step == 0:
                        print_rank(get_log(stats, 0))

                    # Averaged logging + reset at every log_interval global steps.
                    if self.global_iter_count % self.args.log_interval == 0 and self.iter_count % self.args.gradient_accumulation_steps == 0:
                        logging_stats = {k:v/(self.args.log_interval*self.args.gradient_accumulation_steps) for k,v in logging_stats.items()}
                        log_str = get_log(logging_stats, logging_stats.get("elapsed_time", 0) * self.args.gradient_accumulation_steps)
                        print_rank("*" * 100)
                        print_rank(log_str)
                        print_rank(self.args.save)
                        print_rank("*" * 100)
                        save_rank(log_str, os.path.join(self.args.save, "log.txt"))
                        logging_stats = {k:0 for k in logging_stats}

                    # end
                    if (self.global_iter_count >= self.total_steps or self.sampler.epochs >= self.args.epochs):
                        if self.global_iter_count >= self.total_steps:
                            print_rank("Reached total steps {}/{}".format(self.global_iter_count, self.total_steps))
                        else:
                            print_rank("Reached data epochs {}/{}".format(self.sampler.epochs, self.args.epochs)) 
                        # Final checkpoint + evaluation before returning.
                        self.save()
                        results, preds, response_texts = self.evaluate_ppo()
                        if self.eval_lm_pipeline is not None:
                            eval_pt_results = self.evaluate_pt()
                            results.update(eval_pt_results)
                        self.save_evals(preds, results, response_texts)
                        return results
                    
                    self.iter_count += 1
                    if self.iter_count % self.args.gradient_accumulation_steps == 0:
                        self.global_iter_count += 1

                self.post_backward_callback()

            self.post_epoch_callback(training_epoch)

    def post_backward_callback(self):
        # Hook after each pass over the rollout dataloader; no-op by default.
        pass
        
    def post_epoch_callback(self, epoch):
        """After each training epoch: discard old rollouts and sample fresh ones."""
        self.store.clear_history()
        # self.store.load(self.args.save)
        self.sampler.run_sample(
            self.args.num_rollouts_per_device, self.global_iter_count
        )  # Collect more rollouts for training
        
    def prepare_learning(self):
        """Build train/eval dataloaders and derive the total step budget."""
        self.train_dataloader = self.store.create_loader(
            self.args.batch_size, shuffle=True, num_workers=self.args.num_workers, drop_last=True
        )
        
        self.eval_dataloader = self.eval_pipeline.create_loader(
            self.args.batch_size, shuffle=False, num_workers=self.args.num_workers, drop_last=False)

        self.lm_dataloader = self.lm_pipeline.create_loader(
            self.args.batch_size, shuffle=True, num_workers=self.args.num_workers, drop_last=True)
        self.lm_iterator = iter(self.lm_dataloader)
        
        self.eval_lm_dataloader = self.eval_lm_pipeline.create_loader(
            self.args.batch_size, shuffle=False, num_workers=self.args.num_workers, drop_last=False)

        self.n_updates_per_batch = self.args.ppo_epochs
        # Optimizer-step budget, capped by total_iters.
        self.total_steps = int(
            self.args.training_epochs
            * self.n_updates_per_batch
            * len(self.train_dataloader)
            / self.args.gradient_accumulation_steps
        )
        self.total_steps = min(self.total_steps, self.args.total_iters)

    def evaluate(self):
        """Run generation (PPO) and LM evaluations, save outputs, and on rank 0
        compute ROUGE/exact-match metrics and log a summary line."""
        eval_results = {}
        eval_rl_results, preds, response_texts = self.evaluate_ppo()
        eval_results.update(eval_rl_results)
        eval_pt_results = self.evaluate_pt()
        eval_results.update(eval_pt_results)
        
        # Gathered generations can exceed the answer set; trim to match.
        response_texts = response_texts[:len(self.eval_pipeline.ppo_answers)]            
        self.save_evals(preds, eval_results, response_texts)
        
        if get_rank() == 0:
            res = compute_metrics(response_texts, self.eval_pipeline.ppo_answers)
            eval_results.update(res)
            keys = ["rougeL", "exact_match", "rev_kl", "lens", "pt_loss", "lm_loss", "kd_loss"]
            eval_log_str = "eval "
            for key in keys:
                eval_log_str += "| {}: {:.3f} ".format(key, eval_results[key])
            print_rank(eval_log_str)
            save_rank(eval_log_str, os.path.join(self.args.save, "log.txt"))

    def evaluate_ppo(self):  # noqa: C901
        # self.model.eval()
        """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
        stats = {}
        all_full_ids = []
        all_rev_kl = []
        all_lens = []
        
        table = []

        with torch.no_grad():
            for batch in tqdm(self.eval_dataloader, "Generation Evaluation", disable=(not get_rank() == 0)):
                batch, no_model_batch = batch
                batch, _ = self.eval_pipeline.move_to_device(batch, no_model_batch, self.device)
                gen_out = self.generate(
                    **batch,
                    return_dict_in_generate=True,
                    output_scores=True
                )
                full_ids = gen_out.sequences
                gen_logits = gen_out.scores # NOTE: [b, s, h_p]
                # Positions filtered to -inf at sampling time (top-k/top-p).
                inf_mask = torch.isinf(gen_logits)

                all_full_ids.append(full_ids)
                
                input_ids = batch["input_ids"]
                gen_ids = full_ids[:, input_ids.size(1):]
                mask = self.get_mask(full_ids)
                # Shift by one so the mask aligns with the predicting positions.
                mask = mask[:, input_ids.size(1)-1:input_ids.size(1)+gen_ids.size(1)-1]
                lens = torch.sum(mask, dim=-1)
                
                teacher_rewards = self.reward_fn(input_ids, gen_ids)["rewards"] # \log p(y_t | y_{<t}, x)
                _, logprobs = self.compute_logits_and_log_probs(input_ids, gen_ids, inf_mask=inf_mask, base="base") # \log q_{\theta}(y_t | y_{<t}, x)
                
                kl = get_rev_kl(teacher_rewards, logprobs, mask)
                kl = kl.sum(-1)
                
                if self.args.length_norm:
                    kl = kl / lens

                all_rev_kl.append(kl)
                all_lens.append(lens)

            all_full_ids = torch.cat(all_full_ids, dim=0)
            all_rev_kl = torch.cat(all_rev_kl, dim=0)
            all_lens = torch.cat(all_lens, dim=0)

            # Gather generations and stats across data-parallel ranks.
            full_ids = all_gather(all_full_ids, dim=1, world_size=self.dp_world_size, group=self.dp_group, op="stack")
            full_ids = full_ids.view(-1, full_ids.size(-1))

            prompt_ids = full_ids[:, :self.eval_pipeline.max_prompt_length]
            all_rev_kl = all_gather(all_rev_kl, dim=0, world_size=self.dp_world_size, group=self.dp_group)
            stats["rev_kl"] = all_rev_kl.mean()
            all_lens = all_gather(all_lens, dim=0, world_size=self.dp_world_size, group=self.dp_group)
            stats["lens"] = all_lens.float().mean()

            response_texts = []
            if get_rank() == 0:
                prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
                response_texts = self.tokenizer.batch_decode(full_ids[:, self.eval_pipeline.max_prompt_length:], skip_special_tokens=True)
                gen_texts = [p + g for p, g in zip(prompt_texts, response_texts)]

                columns = ["prompts"]
                columns_data = [prompt_texts]
                # in online setting, compute the reward for validation
                columns.append("samples")
                if isinstance(gen_texts[0], str):
                    columns_data.append(gen_texts)
                else:
                    columns_data.append(gen_texts.tolist())

                table.append(list(zip(*columns_data)))

        # Log and display evaluation metrics
        # (`columns`/`table` were filled above; both branches run on rank 0 only).
        if get_rank() == 0:
            rows = sum(list(map(list, zip(*table))), [])

            # Add metrics/rewards to the table's title
            table_title = f"Evaluation #{self.nth_evaluation}"
            for k, x in stats.items():
                if k.startswith("reward") or k.startswith("metrics"):
                    table_title += f" {k}: {significant(x)}"

            rich_table = Table(*columns, title=table_title, show_lines=True)

            for ix in range(min(3, len(rows))):
                rich_table.add_row(*[str(significant(x)) for x in rows[ix]])

            try:
                Console().print(rich_table)
            except:
                pass

        self.nth_evaluation += 1
        return stats, table, response_texts

    def evaluate_pt(self):
        """Evaluate the pre-training losses (pt/lm/kd) over the eval LM dataloader,
        averaged across all data-parallel ranks."""
        all_pt_losses = []
        all_lm_losses = []
        all_kd_losses = []
        for batch in tqdm(self.eval_lm_dataloader, desc="LM Evaluation", disable=(not get_rank() == 0)):
            self.eval_lm_pipeline.move_to_device(*batch, self.device)
            model_batch, _ = batch
            outputs = self.model(**model_batch, return_dict=True, use_cache=False)
            logits = outputs.logits
            with torch.no_grad():
                _, stats = self.losses.pt_loss(batch, logits)
                all_pt_losses.append(stats["pt_loss"])
                all_lm_losses.append(stats["lm_loss"])
                # "ds_loss" from the loss stats is reported here as the kd loss.
                all_kd_losses.append(stats["ds_loss"])
        
        all_pt_losses = torch.tensor(all_pt_losses, device=self.device)
        eval_pt_loss = all_gather(all_pt_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
        
        all_lm_losses = torch.tensor(all_lm_losses, device=self.device)
        eval_lm_loss = all_gather(all_lm_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
        
        all_kd_losses = torch.tensor(all_kd_losses, device=self.device)
        eval_kd_loss = all_gather(all_kd_losses, dim=0, world_size=self.dp_world_size, group=self.dp_group).mean().item()
        
        results = {"pt_loss": eval_pt_loss, "lm_loss": eval_lm_loss, "kd_loss": eval_kd_loss}
        
        return results
    
    def save(self, directory: Optional[str] = None):
        """Creates a checkpoint of the optimizer, scheduler and model"""
        # NOTE(review): duplicated docstring-style literal below is a no-op statement.
        """Creates checkpoint of optimizer, scheduler and a model"""
        base_ckpt_path = directory or self.args.save
        ckpt_dir = os.path.join(base_ckpt_path, f"{self.global_iter_count}")
        os.makedirs(ckpt_dir, exist_ok=True)
        if self.args.model_parallel:
            raise NotImplementedError
        else:
            # Only rank 0 writes; other ranks rely on the barrier-free HF save.
            if get_rank() == 0:
                self.model.module.base_model.save_pretrained(ckpt_dir, safe_serialization=False)
                # torch.save(self.model.module.value_model.state_dict(), os.path.join(ckpt_dir, "value_model.ckpt"))
                print(f"Model save to {ckpt_dir}")
                self.tokenizer.save_pretrained(ckpt_dir)

    def save_evals(self, preds, results, response_texts, directory: Optional[str] = None):
        """Creates a checkpoint of the optimizer, scheduler and model"""
        # NOTE(review): duplicated docstring-style literal below is a no-op statement,
        # and the docstring above is copied from save(); this actually saves eval outputs.
        """Creates checkpoint of optimizer, scheduler and a model"""
        base_ckpt_path = directory or self.args.save
        save_dir = os.path.join(base_ckpt_path, "eval", f"{self.global_iter_count}")
        os.makedirs(save_dir, exist_ok=True)
        
        if get_rank() == 0:
            torch.save(preds, os.path.join(save_dir, "preds.pt"))
            torch.save(results, os.path.join(save_dir, "results.pt"))
            with open(os.path.join(save_dir, "answers.jsonl"), "w") as f:
                for resp in response_texts:
                    f.write(json.dumps({"text": resp}) + "\n")

    def push_to_store(self, data):
        """Append sampled rollouts to this trainer's rollout storage."""
        self.store.push(data)
         
    def generate(self, input_ids, attention_mask=None, mode="base", teacher_mixed_sample=False, **kwargs):
        """Wraps hf's `generate` adding some specific method's defaults"""
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        kwargs = dict(self.generate_kwargs, **kwargs)

        if mode == "base":
            model = self.model.module
        elif mode == "teacher":
            model = self.teacher_model
        else:
            raise NotImplementedError

        # When teacher-mixed sampling is on, the custom generate() mixes the
        # student distribution with the teacher's at ratio teacher_mixed_alpha.
        mix_in_model, mix_in_alpha = None, None
        if teacher_mixed_sample:
            mix_in_model = self.teacher_model
            mix_in_alpha = self.args.teacher_mixed_alpha

        with torch.no_grad():
            
            generation_config = GenerationConfig(**kwargs)
            
            max_new_tokens = generation_config.max_length - input_ids.size(1)
            gen = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=generation_config,
                max_new_tokens=max_new_tokens,
                mix_in_model=mix_in_model,
                mix_in_alpha=mix_in_alpha
            )
            
            # Right-pad sequences to a fixed max_length so batches are stackable.
            gen.sequences = F.pad(
                gen.sequences,
                (0, self.max_length - gen.sequences.shape[1]),
                value=self.tokenizer.pad_token_id,
            )
            
            # Stack per-step scores into [b, s, h_p] and zero-pad the step dim
            # to the fixed response length (max_length - max_prompt_length).
            if gen.scores is not None:
                gen.scores = torch.stack(gen.scores, dim=1)
                gen.scores = torch.cat([
                    gen.scores, 
                    torch.zeros(
                        gen.scores.size(0),
                        self.max_length - self.args.max_prompt_length - gen.scores.size(1),
                        gen.scores.size(2),
                        device=gen.scores.device)],
                    dim=1)
                
            # NOTE: scores: [b, s, h_p]

        return gen

================================================
FILE: minillm/utils.py
================================================
import math
from enum import Enum
from numbers import Number
from typing import Tuple

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
from accelerate import init_empty_weights


from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
)


def get_entropy(gen_logits, inf_mask, mask, model_parallel=False):
    """Per-position entropy of the softmax distribution over `gen_logits`.

    Vocabulary entries flagged by `inf_mask` (or whose logits are infinite)
    contribute zero log-probability; positions where `mask` is zero yield
    zero entropy. Model-parallel execution is not implemented.
    """
    combined_inf_mask = inf_mask | torch.isinf(gen_logits)
    if model_parallel:
        raise NotImplementedError
    # Compute in float32 for numerical stability regardless of input dtype.
    probs = F.softmax(gen_logits, dim=-1, dtype=torch.float32)
    logprobs = F.log_softmax(gen_logits, dim=-1, dtype=torch.float32)
    logprobs = logprobs.masked_fill(combined_inf_mask, 0)
    entropy = -(probs * logprobs).sum(dim=-1)
    return entropy * mask


def get_log_probs(logits, ids, mask, inf_mask=None, model_parallel=False):
    """Gather the log-probabilities of `ids` under `logits`.

    Args:
        logits: unnormalized scores, shape [..., vocab].
        ids: token indices to select, shape matching logits minus last dim.
        mask: positions where the result is kept; elsewhere zeroed.
        inf_mask: optional boolean mask of vocabulary entries to exclude
            (filled with -inf before gathering).
        model_parallel: not implemented.

    Returns:
        Log-probabilities of the selected ids, zeroed outside `mask`.

    Raises:
        AssertionError: if any selected log-probability is inf or nan.
    """
    if model_parallel:
        raise NotImplementedError
    else:
        logprobs = F.log_softmax(logits, dim=-1)
        if inf_mask is not None:
            logprobs = logprobs.masked_fill(inf_mask, -float("inf"))
        logprobs = torch.gather(logprobs, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)
    logprobs = logprobs.masked_fill(~(mask.bool()), 0)

    # Vectorized finiteness check. The original used the Python builtin
    # `all(...)` over the flattened tensor, which iterates element by element
    # (one device sync per element on GPU); torch.isfinite covers both the
    # inf and nan cases in a single kernel.
    assert torch.isfinite(logprobs).all()

    return logprobs


def get_x_entropy(logits_1, logits_2, inf_mask, mask, model_parallel=False):
    """Per-position cross-entropy H(p, q) with p = softmax(logits_1) and
    q = softmax(logits_2).

    Vocabulary entries that are masked or have infinite logits in either
    input contribute zero; the result is zeroed where `mask` is zero.
    Model-parallel execution is not implemented.
    """
    bad_entries = inf_mask | torch.isinf(logits_1) | torch.isinf(logits_2)
    if model_parallel:
        raise NotImplementedError
    # float32 for numerical stability regardless of input dtype.
    p = F.softmax(logits_1, dim=-1, dtype=torch.float32)
    log_q = F.log_softmax(logits_2, dim=-1, dtype=torch.float32)
    log_q = log_q.masked_fill(bad_entries, 0)
    xent = -(p * log_q).sum(dim=-1)
    return xent * mask


def get_rev_kl(log_p, log_q, mask):
    """Pointwise reverse-KL estimator exp(r) - 1 - r with r = mask*(log_p - log_q).

    Positions where `mask` is zero have r = 0 and therefore contribute 0.
    """
    masked_ratio = mask * (log_p - log_q)
    return masked_ratio.float().exp() - masked_ratio - 1


def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
    """
    Computes element-wise mean and variance of the tensor across processes.

    Requires an initialized torch.distributed process group: both the
    sum/count pair and the squared-deviation sum are all-reduced, so every
    rank returns the same global statistics.
    """
    stats = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
    dist.all_reduce(stats, dist.ReduceOp.SUM)
    total, count = stats
    global_mean = total / count

    sq_dev = ((xs - global_mean) ** 2).sum()
    dist.all_reduce(sq_dev, dist.ReduceOp.SUM)
    return global_mean, sq_dev / count, count


def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
    """Normalize `xs` to unit variance; zero-mean unless `shift_mean` is False.

    When `distributed` is set and a process group is initialized, the mean
    and variance are aggregated across all ranks.
    """
    if distributed and dist.is_initialized():
        mean, var, _ = get_global_statistics(xs)
    else:
        var, mean = torch.var_mean(xs)

    normalized = (xs - mean) * torch.rsqrt(var + 1e-8)
    return normalized if shift_mean else normalized + mean


def significant(x: Number, ndigits=2) -> Number:
    """
    Cut the number up to its `ndigits` after the most significant digit.

    Tensors are converted via `.item()` first; zero and non-numeric values
    are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        x = x.item()

    # Zero has no defined magnitude; non-numbers pass through untouched.
    if x == 0 or not isinstance(x, Number):
        return x

    magnitude = int(math.floor(math.log10(abs(x))))
    return round(x, ndigits - magnitude)


class OptimizerName(str, Enum):
    """Supported optimizer names.

    Subclassing ``str`` makes members compare equal to their raw string
    values (``OptimizerName.ADAM == "adam"``), so CLI-provided strings can
    be matched directly against members in `get_optimizer_class`.
    """

    ADAM: str = "adam"
    ADAMW: str = "adamw"
    # NOTE(review): the 8-bit bitsandbytes variants are declared here, but
    # `get_optimizer_class` has no branch for them — selecting one raises
    # ValueError. Confirm whether bnb support was intended.
    ADAM_8BIT_BNB: str = "adam_8bit_bnb"
    ADAMW_8BIT_BNB: str = "adamw_8bit_bnb"
    SGD: str = "sgd"


def get_optimizer_class(name: OptimizerName):
    """
    Returns the optimizer class with the given name.

    Args:
        name (str): Name of the optimizer as found in `OptimizerName`.
            Because `OptimizerName` subclasses `str`, both raw strings and
            enum members match.

    Raises:
        ValueError: if `name` is not a recognized or implemented optimizer.
    """
    # Compare against enum members throughout. The original mixed member
    # comparison (ADAM/ADAMW) with `.value` comparison (SGD); both work for
    # a str-Enum, but the inconsistency was accidental.
    if name == OptimizerName.ADAM:
        return torch.optim.Adam
    if name == OptimizerName.ADAMW:
        return torch.optim.AdamW
    if name == OptimizerName.SGD:
        return torch.optim.SGD
    if name in (OptimizerName.ADAM_8BIT_BNB, OptimizerName.ADAMW_8BIT_BNB):
        # These names are declared in `OptimizerName` but have no
        # implementation here (they would need the optional `bitsandbytes`
        # package). Previously they fell through to a message that falsely
        # listed them as supported; fail with an accurate message instead.
        raise ValueError(
            f"`{name}` requires the `bitsandbytes` package, "
            f"which is not available in this build."
        )
    supported_optimizers = [o.value for o in OptimizerName]
    raise ValueError(
        f"`{name}` is not a supported optimizer. "
        f"Supported optimizers are: {supported_optimizers}"
    )


class SchedulerName(str, Enum):
    """Supported scheduler names.

    Subclassing ``str`` lets raw strings compare equal to the members, so
    `get_scheduler_class` accepts either form.
    """

    COSINE_ANNEALING = "cosine_annealing"
    LINEAR = "linear"


def get_scheduler_class(name: SchedulerName):
    """Return the torch LR-scheduler class registered under `name`.

    Raises ValueError for unrecognized names. String comparison works for
    both raw strings and `SchedulerName` members (str-Enum).
    """
    if name == SchedulerName.LINEAR:
        return LinearLR
    if name == SchedulerName.COSINE_ANNEALING:
        return CosineAnnealingLR
    names = [member.value for member in SchedulerName]
    raise ValueError(
        f"`{name}` is not a supported scheduler. "
        f"Supported schedulers are: {names}"
    )

================================================
FILE: rouge_metric.py
================================================
import string
import json
import os
import argparse
from rouge_score import rouge_scorer
from transformers import AutoTokenizer


default_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# adapted the flowing from Squad v1.1 evaluation, without removing the articles.
def normalize_answer(s):
    """Lower-case `s`, strip ASCII punctuation, and collapse whitespace.

    Adapted from the SQuAD v1.1 normalization, but articles are kept.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    # split()/join collapses runs of whitespace and trims the ends.
    return ' '.join(without_punct.split())


def exact_match(prediction, ground_truth, xlingual=False):
    """True iff prediction and ground truth are identical after normalization.

    `xlingual` is accepted for interface parity with other metrics but unused.
    """
    norm_pred = normalize_answer(prediction)
    norm_gold = normalize_answer(ground_truth)
    return norm_pred == norm_gold


def rouge(prediction, ground_truth, xlingual=False):
    """ROUGE-L F-measure between prediction and ground truth (stemmed).

    `xlingual` is accepted for interface parity with other metrics but unused.
    """
    result = default_rouge_scorer.score(prediction=prediction, target=ground_truth)
    return result["rougeL"].fmeasure


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, xlingual=False):
    """Best score of `metric_fn` over all references for one prediction.

    Raises ValueError (from `max`) when `ground_truths` is empty.
    """
    return max(
        metric_fn(prediction, gt, xlingual=xlingual) for gt in ground_truths
    )


def compute_metrics(predictions, references, xlingual=False):
    """Corpus-level exact-match and ROUGE-L, as percentages rounded to 4 dp.

    Each reference entry must be a *list* of gold answers; for every
    prediction the maximum score over its references is taken. If the two
    lists differ in length, both are truncated to the shorter one (the
    original code asserted equality; truncation keeps partial prediction
    files usable).

    Returns:
        dict with keys "exact_match" and "rougeL".
    """
    min_length = min(len(predictions), len(references))
    predictions = predictions[:min_length]
    references = references[:min_length]

    # Guard against empty input (e.g. an empty prediction file): the
    # original divided by len(references) and raised ZeroDivisionError.
    if min_length == 0:
        return {"exact_match": 0.0, "rougeL": 0.0}

    em, rougeL = 0, 0
    for pred, gold in zip(predictions, references):
        assert isinstance(gold, list)
        em += metric_max_over_ground_truths(
            exact_match, prediction=pred, ground_truths=gold, xlingual=xlingual
        )
        rougeL += metric_max_over_ground_truths(
            rouge, prediction=pred, ground_truths=gold, xlingual=xlingual
        )
    em = 100.0 * em / len(references)
    rougeL = 100.0 * rougeL / len(references)
    metrics = {"exact_match": em, "rougeL": rougeL}
    metrics = {k: round(v, 4) for k, v in metrics.items()}
    return metrics


def compute_grouped_metrics(predictions, references, groups, xlingual=False):
    """Compute metrics separately per group.

    Result keys are suffixed as "<metric>_for_<group>". All three input
    lists must have the same length; `groups[i]` names the bucket of
    example i.
    """
    assert len(predictions) == len(references) == len(groups)

    # Bucket (prediction, reference) pairs by their group label.
    buckets = {}
    for pred, gold, group in zip(predictions, references, groups):
        buckets.setdefault(group, []).append((pred, gold))

    results = {}
    for group, pairs in buckets.items():
        group_preds, group_refs = zip(*pairs)
        group_metrics = compute_metrics(group_preds, group_refs, xlingual=xlingual)
        for metric, value in group_metrics.items():
            results[f"{metric}_for_{group}"] = value
    return results


def parse_args():
    """Build and parse the command-line arguments for this metric script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--prediction_file",
        required=True,
        help=("Jsonl file with each line corresponding to a prediction. "
              "Each json object should have an `id` and a `prediction` key."),
    )
    arg_parser.add_argument(
        "--reference_file",
        required=True,
        help=("Jsonl file with each line corresponding to a reference. "
              "Each json object should have an `id` and a `references` key. "
              "`task_id`, `task_category` and `task_track` are optional, which will be used to "
              "compute the per-task performance, per-category performance and the performance for default (english) / xlingual Tracks."),
    )
    arg_parser.add_argument(
        "--output_file",
        help="Jsonl file to write the results to.",
    )
    arg_parser.add_argument(
        "--model_name",
    )
    return arg_parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Load references: one JSON object per line; the gold answer(s) are read
    # from the "output" field and normalized to a list per example.
    # NOTE(review): the --reference_file help text says objects carry a
    # `references` key, but this code reads `output` — confirm which is right.
    references = []
    with open(args.reference_file) as fin:
        for line in fin:
            instance = json.loads(line)
            if isinstance(instance["output"], list):
                references.append(instance["output"])
            else:
                references.append([instance["output"]])

    # Load predictions: one JSON object per line, text under the "text" key.
    predictions = []
    with open(args.prediction_file) as fin:
        for line in fin:
            prediction = json.loads(line)
            predictions.append(prediction["text"])

    # Evaluation is capped at the first 1000 predictions; references are
    # truncated to match.
    predictions = predictions[:1000]

    references = references[:len(predictions)]

    results = compute_metrics(predictions, references, xlingual=False)

    print(results)

    # NOTE(review): despite the --output_file help text ("Jsonl file to write
    # the results to"), the value is treated as a *directory*; results are
    # written to <output_file>/<model_name>.json.
    if args.output_file:
        os.makedirs(args.output_file, exist_ok=True)
        with open(os.path.join(args.output_file, f"{args.model_name}.json"), "w") as fout:
            json.dump(results, fout, indent=2)
            

================================================
FILE: scripts/gpt2/distillm/train_0.1B_1.5B.sh
================================================
#! /bin/bash

MASTER_ADDR=localhost
MASTER_PORT=${2-2012}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${3-16}

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"

# model
BASE_PATH=${1-"/home/MiniLLM"}
CKPT_NAME="gpt2-base"
CKPT="${BASE_PATH}/results/gpt2/train/init/${CKPT_NAME}"
TEACHER_CKPT_NAME="xlarge-sft"
TEACHER_CKPT="${BASE_PATH}/results/gpt2/train/sft/gpt2-xlarge/"
# data
DATA_DIR="${BASE_PATH}/processed_data/dolly/full/gpt2/"
LM_DATA_DIR="${BASE_PATH}/processed_data/openwebtext/gpt2/512/10M/"
# hp
BATCH_SIZE=8
LR=0.0005
GRAD_ACC=1
EVAL_BATCH_SIZE=64
# length
MAX_LENGTH=512
# runtime
SAVE_PATH="${BASE_PATH}/results/gpt2/train/distill_0.1B_1.5B_final2"
# seed
SEED=10


OPTS=""
# model
OPTS+=" --base-path ${BASE_PATH}"
OPTS+=" --model-path ${CKPT}"
OPTS+=" --teacher-model-path ${TEACHER_CKPT}"
OPTS+=" --ckpt-name ${CKPT_NAME}"
OPTS+=" --teacher-ckpt-name ${TEACHER_CKPT_NAME}"
OPTS+=" --teacher-model-fp16"
OPTS+=" --n-gpu ${GPUS_PER_NODE}"
# data
OPTS+=" --data-dir ${DATA_DIR}"
OPTS+=" --lm-data-dir ${LM_DATA_DIR}"
OPTS+=" --num-workers 4"
OPTS+=" --dev-num 1000"
# hp
OPTS+=" --lr ${LR}"
OPTS+=" --batch-size ${BATCH_SIZE}"
OPTS+=" --eval-batch-size ${EVAL_BATCH_SIZE}"
OPTS+=" --gradient-accumulation-steps ${GRAD_ACC}"
OPTS+=" --warmup-iters 0"
OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 1e-2"
OPTS+=" --clip-grad 1.0"
OPTS+=" --epochs 20"
OPTS+=" --kd-ratio 1.0"
# length
OPTS+=" --max-length ${MAX_LENGTH}"
OPTS+=" --max-prompt-length 256"
# runtime
OPTS+=" --do-train"
OPTS+=" --do-valid"
OPTS+=" --eval-gen"
OPTS+=" --save-interval -1"
OPTS+=" --eval-interval -1"
OPTS+=" --log-interval 4"
OPTS+=" --mid-log-num -1"
OPTS+=" --save ${SAVE_PATH}"
# seed
OPTS+=" --seed ${SEED}"
# deepspeed
OPTS+=" --deepspeed"
OPTS+=" --deepspeed_config ${BASE_PATH}/configs/deepspeed/ds_config.json"
# type
OPTS+=" --type adaptive-sfkl"
# gen
OPTS+=" --do-sample"
OPTS+=" --top-k 0"
OPTS+=" --top-p 1.0"
OPTS+=" --temperature 1.0"
# distillm
OPTS+=" --student-gen"
OPTS+=" --gen-num-beams 1"
OPTS+=" --gen-top-p 1.0"
OPTS+=" --init-threshold 0.0"
OPTS+=" --loss-eps 0.1"
OPTS+=" --capacity 1000"


export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH=${BASE_PATH}
CMD="torchrun ${DISTRIBUTED_ARGS} ${BASE_PATH}/finetune.py ${OPTS} $@"

echo ${CMD}
echo "PYTHONPATH=${PYTHONPA
Download .txt
gitextract_j6ug111_/

├── .gitignore
├── README.md
├── arguments.py
├── configs/
│   ├── deepspeed/
│   │   ├── ds_config.json
│   │   ├── ds_config_fp32.json
│   │   ├── ds_config_zero2.json
│   │   └── ds_config_zero2_offload.json
│   └── hostfiles/
│       ├── node_0_1
│       ├── node_0_1_2_3
│       ├── node_1_2
│       └── node_2_3
├── data_utils/
│   ├── distributed_indexed.py
│   ├── indexed_dataset.py
│   ├── lm_datasets.py
│   └── prompt_datasets.py
├── distillm/
│   ├── __init__.py
│   ├── buffer.py
│   ├── losses.py
│   └── sampler.py
├── evaluate.py
├── evaluate_main.py
├── finetune.py
├── generate.py
├── install.sh
├── minillm/
│   ├── __init__.py
│   ├── data_types.py
│   ├── losses.py
│   ├── model.py
│   ├── pipelines.py
│   ├── reward.py
│   ├── sampler.py
│   ├── storages.py
│   ├── trainer.py
│   └── utils.py
├── rouge_metric.py
├── scripts/
│   ├── gpt2/
│   │   ├── distillm/
│   │   │   ├── train_0.1B_1.5B.sh
│   │   │   ├── train_0.3B_1.5B.sh
│   │   │   └── train_0.7B_1.5B.sh
│   │   ├── eval/
│   │   │   ├── eval_main_dolly.sh
│   │   │   ├── eval_main_self_inst.sh
│   │   │   ├── eval_main_sinst.sh
│   │   │   ├── eval_main_uinst.sh
│   │   │   ├── eval_main_vicuna.sh
│   │   │   └── run_eval.sh
│   │   ├── gkd/
│   │   │   ├── gkd_base.sh
│   │   │   ├── gkd_large.sh
│   │   │   └── gkd_medium.sh
│   │   ├── imitkd/
│   │   │   ├── imitkd_base.sh
│   │   │   ├── imitkd_large.sh
│   │   │   └── imitkd_medium.sh
│   │   ├── init/
│   │   │   ├── init_base.sh
│   │   │   ├── init_large.sh
│   │   │   └── init_medium.sh
│   │   ├── kd/
│   │   │   ├── kd_base.sh
│   │   │   ├── kd_large.sh
│   │   │   └── kd_medium.sh
│   │   ├── minillm/
│   │   │   ├── train_base_xl.sh
│   │   │   ├── train_large_xl.sh
│   │   │   └── train_medium_xl.sh
│   │   ├── seqkd/
│   │   │   ├── seqkd_base.sh
│   │   │   ├── seqkd_large.sh
│   │   │   └── seqkd_medium.sh
│   │   ├── sft/
│   │   │   ├── sft_base.sh
│   │   │   ├── sft_large.sh
│   │   │   ├── sft_medium.sh
│   │   │   └── sft_xlarge.sh
│   │   └── tools/
│   │       ├── generate_data_seqkd.sh
│   │       ├── process_data_dolly.sh
│   │       ├── process_data_pretrain.sh
│   │       └── process_pseudo_data_seqkd.sh
│   ├── openllama2/
│   │   ├── distillm/
│   │   │   └── train_3B_7B_teacher_lora.sh
│   │   ├── eval/
│   │   │   ├── eval_main_dolly_lora.sh
│   │   │   ├── eval_main_self_inst_lora.sh
│   │   │   ├── eval_main_sinst_lora.sh
│   │   │   ├── eval_main_uinst_lora.sh
│   │   │   ├── eval_main_vicuna_lora.sh
│   │   │   └── run_eval.sh
│   │   ├── gkd/
│   │   │   └── gkd_3B_7B_teacher_lora.sh
│   │   ├── imitkd/
│   │   │   └── imitkd_3B_7B_teacher_lora.sh
│   │   ├── init/
│   │   │   └── sft_3B_lora.sh
│   │   ├── kd/
│   │   │   └── kd_3B_7B_teacher_lora.sh
│   │   ├── minillm/
│   │   │   └── train_3B_7B_lora.sh
│   │   ├── seqkd/
│   │   │   └── seqkd_3B_7B_teacher_lora.sh
│   │   ├── sft/
│   │   │   ├── sft_3B_lora.sh
│   │   │   └── sft_7B_lora.sh
│   │   └── tools/
│   │       ├── generate_data_seqkd.sh
│   │       ├── process_data_dolly.sh
│   │       ├── process_data_pretrain.sh
│   │       └── process_pseudo_data_seqkd.sh
│   └── opt/
│       ├── distillm/
│       │   ├── train_0.1B_2.7B.sh
│       │   ├── train_0.3B_2.7B.sh
│       │   └── train_1.3B_2.7B.sh
│       ├── eval/
│       │   ├── eval_main_dolly.sh
│       │   ├── eval_main_self_inst.sh
│       │   ├── eval_main_sinst.sh
│       │   ├── eval_main_uinst.sh
│       │   ├── eval_main_vicuna.sh
│       │   └── run_eval.sh
│       ├── gkd/
│       │   ├── gkd_0.1B_2.7B.sh
│       │   ├── gkd_0.3B_2.7B.sh
│       │   └── gkd_1.3B_2.7B.sh
│       ├── imitkd/
│       │   ├── imitkd_0.1B_2.7B.sh
│       │   ├── imitkd_0.3B_2.7B.sh
│       │   └── imitkd_1.3B_2.7B.sh
│       ├── init/
│       │   ├── init_0.1B.sh
│       │   ├── init_0.3B.sh
│       │   └── init_1.3B.sh
│       ├── kd/
│       │   ├── kd_0.1B_2.7B.sh
│       │   ├── kd_0.3B_2.7B.sh
│       │   └── kd_1.3B_2.7B.sh
│       ├── minillm/
│       │   ├── train_0.1B_2.7B.sh
│       │   ├── train_0.3B_2.7B.sh
│       │   └── train_1.3B_2.7B.sh
│       ├── seqkd/
│       │   ├── seqkd_0.1B_2.7B.sh
│       │   ├── seqkd_0.3B_2.7B.sh
│       │   └── seqkd_1.3B_2.7B.sh
│       ├── sft/
│       │   ├── sft_0.1B.sh
│       │   ├── sft_0.3B.sh
│       │   ├── sft_1.3B.sh
│       │   └── sft_2.7B.sh
│       └── tools/
│           ├── generate_data_seqkd.sh
│           ├── process_data_dolly.sh
│           ├── process_data_pretrain.sh
│           └── process_pseudo_data_seqkd.sh
├── tools/
│   ├── convert_mp.py
│   ├── get_openwebtext.py
│   ├── process_data_dolly.py
│   └── process_data_pretrain.py
├── train_minillm.py
└── utils.py
Download .txt
SYMBOL INDEX (283 symbols across 28 files)

FILE: arguments.py
  function add_model_args (line 22) | def add_model_args(parser: argparse.ArgumentParser):
  function add_runtime_args (line 43) | def add_runtime_args(parser: argparse.ArgumentParser):
  function add_data_args (line 68) | def add_data_args(parser: argparse.ArgumentParser):
  function add_hp_args (line 99) | def add_hp_args(parser: argparse.ArgumentParser):
  function add_ppo_args (line 148) | def add_ppo_args(parser: argparse.ArgumentParser):
  function add_minillm_args (line 163) | def add_minillm_args(parser: argparse.ArgumentParser):
  function add_distillm_args (line 174) | def add_distillm_args(parser: argparse.ArgumentParser):
  function add_gen_args (line 197) | def add_gen_args(parser: argparse.ArgumentParser):
  function add_peft_args (line 211) | def add_peft_args(parser: argparse.ArgumentParser):
  function get_args (line 225) | def get_args():

FILE: data_utils/distributed_indexed.py
  function code (line 41) | def code(dtype):
  function index_file_path (line 48) | def index_file_path(prefix_path):
  function data_file_path (line 52) | def data_file_path(prefix_path):
  class DistributedMMapIndexedDataset (line 56) | class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 57) | class Index(object):
      method __init__ (line 59) | def __init__(self, path):
      method __del__ (line 89) | def __del__(self):
      method dtype (line 94) | def dtype(self):
      method sizes (line 98) | def sizes(self):
      method doc_idx (line 102) | def doc_idx(self):
      method __getitem__ (line 105) | def __getitem__(self, i):
      method __len__ (line 108) | def __len__(self):
    method __init__ (line 111) | def __init__(self, path, name, rank_number, rank_total, cache = None):
    method _probe_data_path (line 133) | def _probe_data_path(self, path, name, rank_total):
    method __getstate__ (line 150) | def __getstate__(self):
    method __setstate__ (line 153) | def __setstate__(self, state):
    method _do_init (line 157) | def _do_init(self, path, name, cache, state):
    method __del__ (line 171) | def __del__(self):
    method __len__ (line 178) | def __len__(self):
    method _next_file (line 181) | def _next_file(self):
    method __relative_idx (line 188) | def __relative_idx(self, idx):
    method __slice_item (line 192) | def __slice_item(self, start, stop):
    method __getitem__ (line 199) | def __getitem__(self, idx):
    method sizes (line 209) | def sizes(self):
    method exists (line 212) | def exists(self, path):

FILE: data_utils/indexed_dataset.py
  function __best_fitting_dtype (line 23) | def __best_fitting_dtype(vocab_size=None):
  function get_available_dataset_impl (line 30) | def get_available_dataset_impl():
  function infer_dataset_impl (line 34) | def infer_dataset_impl(path):
  function make_builder (line 50) | def make_builder(out_file, impl, dtype):
  function make_dataset (line 57) | def make_dataset(path, impl, skip_warmup=False):
  function dataset_exists (line 74) | def dataset_exists(path, impl):
  function read_longs (line 81) | def read_longs(f, n):
  function write_longs (line 87) | def write_longs(f, a):
  function code (line 104) | def code(dtype):
  function index_file_path (line 111) | def index_file_path(prefix_path):
  function data_file_path (line 115) | def data_file_path(prefix_path):
  function create_doc_idx (line 119) | def create_doc_idx(sizes):
  class IndexedDataset (line 127) | class IndexedDataset(torch.utils.data.Dataset):
    method __init__ (line 131) | def __init__(self, path):
    method read_index (line 137) | def read_index(self, path):
    method read_data (line 155) | def read_data(self, path):
    method check_index (line 158) | def check_index(self, i):
    method __del__ (line 162) | def __del__(self):
    method __getitem__ (line 167) | def __getitem__(self, idx):
    method __len__ (line 191) | def __len__(self):
    method num_tokens (line 194) | def num_tokens(self, index):
    method size (line 197) | def size(self, index):
    method exists (line 201) | def exists(path):
    method supports_prefetch (line 207) | def supports_prefetch(self):
  class IndexedCachedDataset (line 211) | class IndexedCachedDataset(IndexedDataset):
    method __init__ (line 213) | def __init__(self, path):
    method supports_prefetch (line 219) | def supports_prefetch(self):
    method prefetch (line 222) | def prefetch(self, indices):
    method __getitem__ (line 247) | def __getitem__(self, idx):
  class IndexedDatasetBuilder (line 264) | class IndexedDatasetBuilder(object):
    method __init__ (line 275) | def __init__(self, out_file, dtype=np.int32):
    method add_item (line 284) | def add_item(self, tensor):
    method end_document (line 291) | def end_document(self):
    method merge_file_ (line 294) | def merge_file_(self, another_file):
    method finalize (line 314) | def finalize(self, index_file):
  function _warmup_mmap_file (line 329) | def _warmup_mmap_file(path):
  class MMapIndexedDataset (line 335) | class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index (line 336) | class Index(object):
      method writer (line 340) | def writer(cls, path, dtype):
      method __init__ (line 385) | def __init__(self, path, skip_warmup=False):
      method __del__ (line 422) | def __del__(self):
      method dtype (line 427) | def dtype(self):
      method sizes (line 431) | def sizes(self):
      method doc_idx (line 435) | def doc_idx(self):
      method __getitem__ (line 439) | def __getitem__(self, i):
      method __len__ (line 442) | def __len__(self):
    method __init__ (line 445) | def __init__(self, path, skip_warmup=False):
    method __getstate__ (line 454) | def __getstate__(self):
    method __setstate__ (line 457) | def __setstate__(self, state):
    method _do_init (line 460) | def _do_init(self, path, skip_warmup):
    method __del__ (line 472) | def __del__(self):
    method __len__ (line 477) | def __len__(self):
    method __getitem__ (line 481) | def __getitem__(self, idx):
    method get (line 501) | def get(self, idx, offset=0, length=None):
    method sizes (line 516) | def sizes(self):
    method supports_prefetch (line 530) | def supports_prefetch(self):
    method exists (line 534) | def exists(path):
  class MMapIndexedDatasetBuilder (line 540) | class MMapIndexedDatasetBuilder(object):
    method __init__ (line 541) | def __init__(self, out_file, dtype=np.int64):
    method add_item (line 547) | def add_item(self, tensor):
    method end_document (line 552) | def end_document(self):
    method merge_file_ (line 555) | def merge_file_(self, another_file):
    method finalize (line 567) | def finalize(self, index_file):

FILE: data_utils/lm_datasets.py
  class LMTrainDataset (line 15) | class LMTrainDataset(Dataset):
    method __init__ (line 16) | def __init__(self, args, tokenizer, path, split, num, ratio, rng_sampl...
    method __len__ (line 40) | def __len__(self):
    method __getitem__ (line 43) | def __getitem__(self, index):
    method _get_lm (line 46) | def _get_lm(self, index):
    method _process_lm (line 53) | def _process_lm(self, i, samp, model_data, no_model_data, gen_data):
    method move_to_device (line 77) | def move_to_device(self, model_data, no_model_data, gen_data, device):
    method collate (line 89) | def collate(self, samples):

FILE: data_utils/prompt_datasets.py
  class PromptDataset (line 13) | class PromptDataset(Dataset):
    method __init__ (line 14) | def __init__(self, args, tokenizer, split, data_path=None, num=-1):
    method __len__ (line 50) | def __len__(self):
    method load_data_json (line 53) | def load_data_json(self, data_path):
    method load_data_txt (line 80) | def load_data_txt(self, data_path):
    method verbalizer (line 93) | def verbalizer(self):
    method __getitem__ (line 96) | def __getitem__(self, index: int):
    method collate (line 114) | def collate(self, samples):
    method move_to_device (line 141) | def move_to_device(self, model_batch, no_model_batch, device):

FILE: distillm/buffer.py
  class ReplayBuffer (line 16) | class ReplayBuffer:
    method __init__ (line 17) | def __init__(self, args):
    method __len__ (line 28) | def __len__(self):
    method sample (line 31) | def sample(self):
    method move_to_device (line 54) | def move_to_device(self, model_data, no_model_data, device):
    method move_to_memory (line 63) | def move_to_memory(self, model_data, no_model_data):

FILE: distillm/losses.py
  function forward_kl (line 4) | def forward_kl(logits, teacher_logits, no_model_batch):
  function reverse_kl (line 14) | def reverse_kl(logits, teacher_logits, no_model_batch):
  function symmetric_kl (line 26) | def symmetric_kl(logits, teacher_logits, no_model_batch, lam=0.9):
  function js_distance (line 32) | def js_distance(logits, teacher_logits, no_model_batch, lam=0.9):
  function tv_distance (line 55) | def tv_distance(logits, teacher_logits, no_model_batch):
  function skewed_forward_kl (line 66) | def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
  function skewed_reverse_kl (line 80) | def skewed_reverse_kl(logits, teacher_logits, no_model_batch, lam=0.1):

FILE: distillm/sampler.py
  class SampleGenerator (line 6) | class SampleGenerator():
    method __init__ (line 7) | def __init__(self, args, tokenizer):
    method run_sample (line 26) | def run_sample(self, model, gen_data):

FILE: evaluate.py
  function setup_model (line 23) | def setup_model(args, ds_config, device):
  function main (line 44) | def main():

FILE: evaluate_main.py
  function prepare_dataset_main (line 22) | def prepare_dataset_main(args, tokenizer):
  function run_model (line 29) | def run_model(args, tokenizer, model, dataset: PromptDataset, epoch, dev...
  function evaluate_main (line 124) | def evaluate_main(args, tokenizer, model, dataset: PromptDataset, split,...

FILE: finetune.py
  function get_teacher_model (line 50) | def get_teacher_model(args, device):
  function get_optimizer (line 77) | def get_optimizer(args, model):
  function get_learning_rate_scheduler (line 95) | def get_learning_rate_scheduler(args, optimizer):
  function setup_model_and_optimizer (line 119) | def setup_model_and_optimizer(args, ds_config, device, set_optim=True):
  function prepare_dataset (line 143) | def prepare_dataset(args, tokenizer):
  function pt_loss (line 162) | def pt_loss(args, model, model_batch, no_model_batch):
  function get_distil_loss (line 171) | def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, ...
  function get_teacher_lm_loss (line 196) | def get_teacher_lm_loss(args, tokenizer, model, teacher_model, model_bat...
  function finetune (line 238) | def finetune(args, tokenizer: AutoTokenizer, model: deepspeed.DeepSpeedE...
  function evaluate (line 431) | def evaluate(args, tokenizer, model, dataset: LMTrainDataset, split, epo...
  function main (line 538) | def main():

FILE: generate.py
  function setup_model (line 28) | def setup_model(args, ds_config, device):
  function prepare_dataset (line 48) | def prepare_dataset(args, tokenizer):
  function generate (line 55) | def generate(args, tokenizer, model, dataset, device):
  function main (line 123) | def main():

FILE: minillm/__init__.py
  function train (line 10) | def train(

FILE: minillm/data_types.py
  class PromptElement (line 7) | class PromptElement:
  class PromptBatch (line 23) | class PromptBatch:
  class PPORLElement (line 39) | class PPORLElement:
  class PPORLBatch (line 80) | class PPORLBatch:

FILE: minillm/losses.py
  class Loss (line 15) | class Loss():
    method __init__ (line 16) | def __init__(self, args, trainer):
    method _get_cumsum_rewards (line 20) | def _get_cumsum_rewards(self, rewards):
    method _get_advantages_and_returns (line 27) | def _get_advantages_and_returns(
    method _pg_loss (line 58) | def _pg_loss(
    method _reg_loss (line 98) | def _reg_loss(self, query_ids, response_ids, mask, logits, inf_mask, s...
    method get_input_batch (line 110) | def get_input_batch(self, ppo_batch: PPORLBatch, pt_batch):
    method ppo_loss (line 122) | def ppo_loss(self, batch: PPORLBatch, logits):
    method pt_loss (line 194) | def pt_loss(self, batch, logits):

FILE: minillm/model.py
  class PPOModel (line 8) | class PPOModel(nn.Module):
    method __init__ (line 9) | def __init__(self, args, device):
    method forward (line 16) | def forward(self, **x):
    method generate (line 20) | def generate(self, **x):
    method set_force_gradient_checkpointing (line 23) | def set_force_gradient_checkpointing(self, value):

FILE: minillm/pipelines.py
  class PPOPipeline (line 15) | class PPOPipeline():
    method __init__ (line 16) | def __init__(self, args, tokenizer, split, ppo_data_path=None, fix_pro...
    method __len__ (line 41) | def __len__(self):
    method __getitem__ (line 44) | def __getitem__(self, index: int):
    method collate (line 60) | def collate(self, samples):
    method move_to_device (line 88) | def move_to_device(self, model_batch, no_model_batch, device):
    method create_loader (line 96) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...
  class LMPipeline (line 110) | class LMPipeline():
    method __init__ (line 111) | def __init__(self, args, tokenizer, split, lm_data_path=None, num=-1):
    method __len__ (line 126) | def __len__(self):
    method __getitem__ (line 129) | def __getitem__(self, index):
    method _get_lm (line 132) | def _get_lm(self, index):
    method _process_lm (line 139) | def _process_lm(self, i, samp, model_data, no_model_data):
    method move_to_device (line 157) | def move_to_device(self, model_batch, no_model_batch, device):
    method collate (line 166) | def collate(self, samples):
    method create_loader (line 189) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...

FILE: minillm/reward.py
  class Reward (line 8) | class Reward():
    method __init__ (line 9) | def __init__(self, args, tokenizer: AutoTokenizer, model: AutoModelFor...
    method get_input_batch (line 16) | def get_input_batch(self, input_ids, gen_ids, output_pos=True):
    method reward_fn (line 33) | def reward_fn(self, input_ids, gen_ids, inf_mask=None, output_pos=True):

FILE: minillm/sampler.py
  class PPOSampler (line 12) | class PPOSampler():
    method __init__ (line 18) | def __init__(
    method run_sample (line 39) | def run_sample(self, num_rollouts_per_device: int = 1024, iter_count: ...

FILE: minillm/storages.py
  class BaseRolloutStore (line 17) | class BaseRolloutStore(Dataset):
    method __init__ (line 18) | def __init__(self, capacity=-1):
    method push (line 23) | def push(self, exps: Iterable[Any]):
    method __getitem__ (line 29) | def __getitem__(self, index: int) -> PPORLElement:
    method __len__ (line 32) | def __len__(self) -> int:
    method create_loader (line 36) | def create_loader(
    method broadcast (line 53) | def broadcast(self, batch, src=0, group=None):
    method move_to_device (line 57) | def move_to_device(self, batch, device):
  class PPORolloutStorage (line 61) | class PPORolloutStorage(BaseRolloutStore):
    method __init__ (line 66) | def __init__(self, pad_token_id, seed):
    method push (line 74) | def push(self, exps: Iterable[PPORLElement]):
    method save (line 77) | def save(self, path):
    method load (line 85) | def load(self, path):
    method clear_history (line 89) | def clear_history(self):
    method export_history (line 92) | def export_history(self, location: str):
    method __getitem__ (line 104) | def __getitem__(self, index: int) -> PPORLElement:
    method __len__ (line 107) | def __len__(self) -> int:
    method collate (line 110) | def collate(self, elems: Iterable[PPORLElement]):
    method create_loader (line 170) | def create_loader(self, batch_size: int, shuffle=False, drop_last: boo...
    method broadcast (line 177) | def broadcast(self, batch: PPORLBatch, src=0, group=None):
    method move_to_device (line 181) | def move_to_device(self, batch: PPORLBatch, device):

FILE: minillm/trainer.py
  class PPOTrainer (line 43) | class PPOTrainer():
    method __init__ (line 48) | def __init__(self, args, tokenizer: AutoTokenizer, reward_fn, ds_config):
    method set_teacher_model (line 94) | def set_teacher_model(self, model):
    method set_sampler (line 97) | def set_sampler(self, sampler):
    method setup_optimizer (line 100) | def setup_optimizer(self):
    method setup_scheduler (line 114) | def setup_scheduler(self):
    method setup_ds (line 128) | def setup_ds(self, model, optimizer=None, scheduler=None):
    method add_eval_pipeline (line 139) | def add_eval_pipeline(self, eval_pipeline: PPOPipeline):
    method add_lm_pipeline (line 143) | def add_lm_pipeline(self, lm_pipeline: LMPipeline, eval_lm_pipeline: L...
    method get_model_inputs (line 147) | def get_model_inputs(
    method get_mask (line 170) | def get_mask(self, tokens):
    method forward_model (line 176) | def forward_model(self, batch):
    method compute_logits_and_log_probs (line 184) | def compute_logits_and_log_probs(self, query_ids, response_ids, inf_ma...
    method train (line 220) | def train(self):
    method post_backward_callback (line 372) | def post_backward_callback(self):
    method post_epoch_callback (line 375) | def post_epoch_callback(self, epoch):
    method prepare_learning (line 382) | def prepare_learning(self):
    method evaluate (line 406) | def evaluate(self):
    method evaluate_ppo (line 426) | def evaluate_ppo(self):  # noqa: C901
    method evaluate_pt (line 522) | def evaluate_pt(self):
    method save (line 550) | def save(self, directory: Optional[str] = None):
    method save_evals (line 565) | def save_evals(self, preds, results, response_texts, directory: Option...
    method push_to_store (line 579) | def push_to_store(self, data):
    method generate (line 582) | def generate(self, input_ids, attention_mask=None, mode="base", teache...

FILE: minillm/utils.py
  function get_entropy (line 19) | def get_entropy(gen_logits, inf_mask, mask, model_parallel=False):
  function get_log_probs (line 32) | def get_log_probs(logits, ids, mask, inf_mask=None, model_parallel=False):
  function get_x_entropy (line 48) | def get_x_entropy(logits_1, logits_2, inf_mask, mask, model_parallel=Fal...
  function get_rev_kl (line 61) | def get_rev_kl(log_p, log_q, mask):
  function get_global_statistics (line 67) | def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
  function whiten (line 82) | def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch...
  function significant (line 95) | def significant(x: Number, ndigits=2) -> Number:
  class OptimizerName (line 108) | class OptimizerName(str, Enum):
  function get_optimizer_class (line 118) | def get_optimizer_class(name: OptimizerName):
  class SchedulerName (line 138) | class SchedulerName(str, Enum):
  function get_scheduler_class (line 145) | def get_scheduler_class(name: SchedulerName):

FILE: rouge_metric.py
  function normalize_answer (line 12) | def normalize_answer(s):
  function exact_match (line 28) | def exact_match(prediction, ground_truth, xlingual=False):
  function rouge (line 32) | def rouge(prediction, ground_truth, xlingual=False):
  function metric_max_over_ground_truths (line 38) | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, ...
  function compute_metrics (line 46) | def compute_metrics(predictions, references, xlingual=False):
  function compute_grouped_metrics (line 69) | def compute_grouped_metrics(predictions, references, groups, xlingual=Fa...
  function parse_args (line 87) | def parse_args():

FILE: tools/convert_mp.py
  function main (line 24) | def main():

FILE: tools/process_data_dolly.py
  class Encoder (line 15) | class Encoder(object):
    method __init__ (line 16) | def __init__(self, args):
    method initializer (line 19) | def initializer(self):
    method encode (line 22) | def encode(self, line):
  function main (line 65) | def main():

FILE: tools/process_data_pretrain.py
  class Encoder (line 14) | class Encoder(object):
    method __init__ (line 15) | def __init__(self, args):
    method initializer (line 18) | def initializer(self):
    method encode (line 21) | def encode(self, line):
  function main (line 28) | def main():

FILE: train_minillm.py
  function get_teacher_model (line 19) | def get_teacher_model(args, device):
  function main (line 44) | def main():

FILE: utils.py
  function print_args (line 24) | def print_args(args):
  function save_rank (line 33) | def save_rank(log_str, save_path, rank=0):
  function print_rank (line 39) | def print_rank(*args, rank=0, **kwargs):
  function all_gather (line 45) | def all_gather(t, dim=0, world_size=None, group=None, op="cat"):
  function set_random_seed (line 58) | def set_random_seed(seed, mp=False):
  function init_distributed (line 69) | def init_distributed(args):
  function init_distributed_ds (line 86) | def init_distributed_ds(args):
  function initialize (line 103) | def initialize(args):
  function get_model (line 120) | def get_model(args, device):
  function get_optimizer_params (line 176) | def get_optimizer_params(args, model: nn.Module):
  function get_optimizer_params_peft (line 190) | def get_optimizer_params_peft(args, model: nn.Module):
  function get_tokenizer (line 200) | def get_tokenizer(args):
  function load_parallel (line 212) | def load_parallel(model, load_dir):
  function save_parallel (line 222) | def save_parallel(model, save_dir):
Condensed preview — 130 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (432K chars).
[
  {
    "path": ".gitignore",
    "chars": 3813,
    "preview": "# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks\n# Edit at https://www.toptal.com/de"
  },
  {
    "path": "README.md",
    "chars": 9339,
    "preview": "# DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)\n\n<a href=\"https://arxiv.org/abs/2402."
  },
  {
    "path": "arguments.py",
    "chars": 15168,
    "preview": "# coding=utf-8\n# Copyright 2020 The OpenBMB team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
  },
  {
    "path": "configs/deepspeed/ds_config.json",
    "chars": 377,
    "preview": "{\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"gradient_accumulation_steps\": 1,\n    \"zero_optimization\": {\n        \"sta"
  },
  {
    "path": "configs/deepspeed/ds_config_fp32.json",
    "chars": 258,
    "preview": "{\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"gradient_accumulation_steps\": 1,\n    \"zero_optimization\": {\n        \"sta"
  },
  {
    "path": "configs/deepspeed/ds_config_zero2.json",
    "chars": 659,
    "preview": "{\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"gradient_accumulation_steps\": 1,\n    \"zero_optimization\": {\n        \"sta"
  },
  {
    "path": "configs/deepspeed/ds_config_zero2_offload.json",
    "chars": 700,
    "preview": "{\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"gradient_accumulation_steps\": 1,\n    \"zero_optimization\": {\n        \"sta"
  },
  {
    "path": "configs/hostfiles/node_0_1",
    "chars": 29,
    "preview": "node-0 slots=8\nnode-1 slots=8"
  },
  {
    "path": "configs/hostfiles/node_0_1_2_3",
    "chars": 59,
    "preview": "node-0 slots=8\nnode-1 slots=8\nnode-2 slots=8\nnode-3 slots=8"
  },
  {
    "path": "configs/hostfiles/node_1_2",
    "chars": 29,
    "preview": "node-1 slots=8\nnode-2 slots=8"
  },
  {
    "path": "configs/hostfiles/node_2_3",
    "chars": 29,
    "preview": "node-2 slots=8\nnode-3 slots=8"
  },
  {
    "path": "data_utils/distributed_indexed.py",
    "chars": 7176,
    "preview": "# coding=utf-8\n# Copyright 2020 The OpenBMB team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
  },
  {
    "path": "data_utils/indexed_dataset.py",
    "chars": 18552,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
  },
  {
    "path": "data_utils/lm_datasets.py",
    "chars": 4283,
    "preview": "import random\nimport torch\nimport os\nimport json\nimport pickle\nimport numpy as np\nfrom torch.utils.data import Dataset\nf"
  },
  {
    "path": "data_utils/prompt_datasets.py",
    "chars": 5882,
    "preview": "import random\nimport torch\nimport os\nfrom torch.utils.data import Dataset\nfrom .distributed_indexed import DistributedMM"
  },
  {
    "path": "distillm/__init__.py",
    "chars": 210,
    "preview": "from .losses import forward_kl, reverse_kl, symmetric_kl, js_distance, tv_distance\nfrom .losses import skewed_forward_kl"
  },
  {
    "path": "distillm/buffer.py",
    "chars": 3053,
    "preview": "import random\nimport torch\nimport os\nimport json\nimport pickle\nimport numpy as np\nfrom torch.utils.data import Dataset\n\n"
  },
  {
    "path": "distillm/losses.py",
    "chars": 4906,
    "preview": "import torch\nimport torch.nn.functional as F\n\ndef forward_kl(logits, teacher_logits, no_model_batch):\n    teacher_probs "
  },
  {
    "path": "distillm/sampler.py",
    "chars": 2812,
    "preview": "import torch\nimport os\nfrom transformers import GenerationConfig\n\n\nclass SampleGenerator():\n    def __init__(self, args,"
  },
  {
    "path": "evaluate.py",
    "chars": 2367,
    "preview": "import time\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport deepspeed\n\nimport json\n\nfrom arguments impor"
  },
  {
    "path": "evaluate_main.py",
    "chars": 6661,
    "preview": "from data_utils.prompt_datasets import PromptDataset\nfrom transformers import GenerationConfig\nimport os\nimport nltk\nnlt"
  },
  {
    "path": "finetune.py",
    "chars": 24409,
    "preview": "import time\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch."
  },
  {
    "path": "generate.py",
    "chars": 5155,
    "preview": "import time\nimport os\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader, Distribute"
  },
  {
    "path": "install.sh",
    "chars": 420,
    "preview": "export NCCL_DEBUG=\"\"\n# conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -"
  },
  {
    "path": "minillm/__init__.py",
    "chars": 1588,
    "preview": "from deepspeed import DeepSpeedConfig\nfrom typing import Optional\n\n# from trlx.utils.loading import get_orchestrator, ge"
  },
  {
    "path": "minillm/data_types.py",
    "chars": 3589,
    "preview": "from dataclasses import dataclass\nfrom typing import Iterable\nfrom torchtyping import TensorType\n\n\n@dataclass\nclass Prom"
  },
  {
    "path": "minillm/losses.py",
    "chars": 9137,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple\nfrom torchtyping i"
  },
  {
    "path": "minillm/model.py",
    "chars": 722,
    "preview": "import torch.nn as nn\nfrom transformers import (\n    AutoConfig,)\n\nfrom utils import get_model\n\n\nclass PPOModel(nn.Modul"
  },
  {
    "path": "minillm/pipelines.py",
    "chars": 8225,
    "preview": "import os\nimport json\nimport torch\nimport random\nimport numpy as np\nfrom torch.utils.data import DataLoader, Distributed"
  },
  {
    "path": "minillm/reward.py",
    "chars": 2872,
    "preview": "import torch\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoTokenizer,\n    mpu)\n\n\nclass Reward():\n    def "
  },
  {
    "path": "minillm/sampler.py",
    "chars": 8436,
    "preview": "import torch\nimport os\n\nfrom .data_types import PromptBatch, PPORLElement\nfrom .pipelines import PPOPipeline\nfrom .train"
  },
  {
    "path": "minillm/storages.py",
    "chars": 5807,
    "preview": "import json\nimport os\nimport time\nfrom abc import abstractmethod\nfrom typing import Any, Callable, Iterable\n\nimport torc"
  },
  {
    "path": "minillm/trainer.py",
    "chars": 26448,
    "preview": "import json\nimport os\nimport deepspeed\nfrom time import time\nfrom typing import Optional, Tuple\nfrom collections import "
  },
  {
    "path": "minillm/utils.py",
    "chars": 4793,
    "preview": "import math\nfrom enum import Enum\nfrom numbers import Number\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.func"
  },
  {
    "path": "rouge_metric.py",
    "chars": 4902,
    "preview": "import string\nimport json\nimport os\nimport argparse\nfrom rouge_score import rouge_scorer\nfrom transformers import AutoTo"
  },
  {
    "path": "scripts/gpt2/distillm/train_0.1B_1.5B.sh",
    "chars": 2554,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/distillm/train_0.3B_1.5B.sh",
    "chars": 2593,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/distillm/train_0.7B_1.5B.sh",
    "chars": 2592,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/eval/eval_main_dolly.sh",
    "chars": 1639,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/gpt2/eval/eval_main_self_inst.sh",
    "chars": 1647,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/gpt2/eval/eval_main_sinst.sh",
    "chars": 1669,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/gpt2/eval/eval_main_uinst.sh",
    "chars": 1672,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/gpt2/eval/eval_main_vicuna.sh",
    "chars": 1641,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/gpt2/eval/run_eval.sh",
    "chars": 802,
    "preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\nfor seed in 10 20 30 40 50\ndo\n    CUDA_VISIBLE_DEVICES=${DEVICE} ba"
  },
  {
    "path": "scripts/gpt2/gkd/gkd_base.sh",
    "chars": 2343,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/gkd/gkd_large.sh",
    "chars": 2349,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/gkd/gkd_medium.sh",
    "chars": 2351,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/imitkd/imitkd_base.sh",
    "chars": 2424,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/imitkd/imitkd_large.sh",
    "chars": 2426,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/imitkd/imitkd_medium.sh",
    "chars": 2428,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/init/init_base.sh",
    "chars": 2056,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/init/init_large.sh",
    "chars": 2062,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/init/init_medium.sh",
    "chars": 2064,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/kd/kd_base.sh",
    "chars": 2300,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/kd/kd_large.sh",
    "chars": 2320,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/kd/kd_medium.sh",
    "chars": 2323,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/minillm/train_base_xl.sh",
    "chars": 2516,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/minillm/train_large_xl.sh",
    "chars": 2517,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/minillm/train_medium_xl.sh",
    "chars": 2519,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/seqkd/seqkd_base.sh",
    "chars": 2359,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/seqkd/seqkd_large.sh",
    "chars": 2325,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/seqkd/seqkd_medium.sh",
    "chars": 2328,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/sft/sft_base.sh",
    "chars": 2057,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/sft/sft_large.sh",
    "chars": 2063,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/sft/sft_medium.sh",
    "chars": 2065,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/sft/sft_xlarge.sh",
    "chars": 2112,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/tools/generate_data_seqkd.sh",
    "chars": 1568,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/gpt2/tools/process_data_dolly.sh",
    "chars": 824,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
  },
  {
    "path": "scripts/gpt2/tools/process_data_pretrain.sh",
    "chars": 412,
    "preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n    --data"
  },
  {
    "path": "scripts/gpt2/tools/process_pseudo_data_seqkd.sh",
    "chars": 740,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
  },
  {
    "path": "scripts/openllama2/distillm/train_3B_7B_teacher_lora.sh",
    "chars": 2970,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/eval/eval_main_dolly_lora.sh",
    "chars": 1980,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/openllama2/eval/eval_main_self_inst_lora.sh",
    "chars": 1988,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/openllama2/eval/eval_main_sinst_lora.sh",
    "chars": 2010,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/openllama2/eval/eval_main_uinst_lora.sh",
    "chars": 2013,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/openllama2/eval/eval_main_vicuna_lora.sh",
    "chars": 1982,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/openllama2/eval/run_eval.sh",
    "chars": 935,
    "preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\n# dolly eval\nfor seed in 10 20 30 40 50\ndo\n    CUDA_VISIBLE_DEVICES"
  },
  {
    "path": "scripts/openllama2/gkd/gkd_3B_7B_teacher_lora.sh",
    "chars": 2937,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/imitkd/imitkd_3B_7B_teacher_lora.sh",
    "chars": 2943,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/init/sft_3B_lora.sh",
    "chars": 2164,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/kd/kd_3B_7B_teacher_lora.sh",
    "chars": 2657,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/minillm/train_3B_7B_lora.sh",
    "chars": 3026,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/seqkd/seqkd_3B_7B_teacher_lora.sh",
    "chars": 2658,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/sft/sft_3B_lora.sh",
    "chars": 2149,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/sft/sft_7B_lora.sh",
    "chars": 2149,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/tools/generate_data_seqkd.sh",
    "chars": 1804,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/openllama2/tools/process_data_dolly.sh",
    "chars": 842,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
  },
  {
    "path": "scripts/openllama2/tools/process_data_pretrain.sh",
    "chars": 421,
    "preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n    --data"
  },
  {
    "path": "scripts/openllama2/tools/process_pseudo_data_seqkd.sh",
    "chars": 789,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
  },
  {
    "path": "scripts/opt/distillm/train_0.1B_2.7B.sh",
    "chars": 2627,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/distillm/train_0.3B_2.7B.sh",
    "chars": 2627,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/distillm/train_1.3B_2.7B.sh",
    "chars": 2627,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/eval/eval_main_dolly.sh",
    "chars": 1717,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/opt/eval/eval_main_self_inst.sh",
    "chars": 1725,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/opt/eval/eval_main_sinst.sh",
    "chars": 1751,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/opt/eval/eval_main_uinst.sh",
    "chars": 1750,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/opt/eval/eval_main_vicuna.sh",
    "chars": 1719,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-1}\n\nDISTRIBUTED_ARGS=\"-"
  },
  {
    "path": "scripts/opt/eval/run_eval.sh",
    "chars": 801,
    "preview": "#!/bin/bash\n\nMASTER_PORT=2040\nDEVICE=${1}\nckpt=${2}\n\n# dolly eval\nfor seed in $SEED\ndo\n    CUDA_VISIBLE_DEVICES=${DEVICE"
  },
  {
    "path": "scripts/opt/gkd/gkd_0.1B_2.7B.sh",
    "chars": 2476,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/gkd/gkd_0.3B_2.7B.sh",
    "chars": 2477,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/gkd/gkd_1.3B_2.7B.sh",
    "chars": 2477,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/imitkd/imitkd_0.1B_2.7B.sh",
    "chars": 2479,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/imitkd/imitkd_0.3B_2.7B.sh",
    "chars": 2480,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/imitkd/imitkd_1.3B_2.7B.sh",
    "chars": 2480,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/init/init_0.1B.sh",
    "chars": 2101,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/init/init_0.3B.sh",
    "chars": 2065,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/init/init_1.3B.sh",
    "chars": 2100,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/kd/kd_0.1B_2.7B.sh",
    "chars": 2403,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/kd/kd_0.3B_2.7B.sh",
    "chars": 2403,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/kd/kd_1.3B_2.7B.sh",
    "chars": 2403,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/minillm/train_0.1B_2.7B.sh",
    "chars": 2528,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/minillm/train_0.3B_2.7B.sh",
    "chars": 2536,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/minillm/train_1.3B_2.7B.sh",
    "chars": 2536,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/seqkd/seqkd_0.1B_2.7B.sh",
    "chars": 2413,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/seqkd/seqkd_0.3B_2.7B.sh",
    "chars": 2413,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/seqkd/seqkd_1.3B_2.7B.sh",
    "chars": 2413,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/sft/sft_0.1B.sh",
    "chars": 2099,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/sft/sft_0.3B.sh",
    "chars": 2101,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/sft/sft_1.3B.sh",
    "chars": 2143,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/sft/sft_2.7B.sh",
    "chars": 2100,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2012}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/tools/generate_data_seqkd.sh",
    "chars": 1595,
    "preview": "#! /bin/bash\n\nMASTER_ADDR=localhost\nMASTER_PORT=${2-2113}\nNNODES=1\nNODE_RANK=0\nGPUS_PER_NODE=${3-16}\n\nDISTRIBUTED_ARGS=\""
  },
  {
    "path": "scripts/opt/tools/process_data_dolly.sh",
    "chars": 818,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\n# only prompt for MiniLLM train\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PA"
  },
  {
    "path": "scripts/opt/tools/process_data_pretrain.sh",
    "chars": 408,
    "preview": "BASE_PATH=${1}\n\nMAX_LENGTH=512\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_pretrain.py \\\n    --data"
  },
  {
    "path": "scripts/opt/tools/process_pseudo_data_seqkd.sh",
    "chars": 743,
    "preview": "BASE_PATH=${1}\n\nexport TF_CPP_MIN_LOG_LEVEL=3\n\nPYTHONPATH=${BASE_PATH} python3 ${BASE_PATH}/tools/process_data_dolly.py "
  },
  {
    "path": "tools/convert_mp.py",
    "chars": 4251,
    "preview": "#coding:utf-8\nimport torch\nimport argparse\nimport os\nfrom transformers import AutoModelForCausalLM\nfrom transformers imp"
  },
  {
    "path": "tools/get_openwebtext.py",
    "chars": 343,
    "preview": "import datasets\nimport os\nimport re\n\ndataset = datasets.load_dataset('openwebtext', split='train')\n\nos.makedirs(\"data/op"
  },
  {
    "path": "tools/process_data_dolly.py",
    "chars": 6422,
    "preview": "import multiprocessing\nimport os\nimport time\nimport torch\nimport json\nimport sys\nfrom numerize.numerize import numerize\n"
  },
  {
    "path": "tools/process_data_pretrain.py",
    "chars": 3784,
    "preview": "import multiprocessing\nimport os\nimport time\nimport torch\nimport sys\nfrom numerize.numerize import numerize\nimport numpy"
  },
  {
    "path": "train_minillm.py",
    "chars": 2608,
    "preview": "import torch\nimport os\nimport json\nimport torch.distributed as dist\nfrom accelerate import init_empty_weights\n\nfrom tran"
  },
  {
    "path": "utils.py",
    "chars": 8111,
    "preview": "from typing import Dict\nimport numpy as np\nimport os\nimport time\nimport torch.distributed as dist\nfrom torch.distributed"
  }
]

About this extraction

This page contains the full source code of the jongwooko/distillm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 130 files (390.6 KB), approximately 117.1k tokens, and a symbol index with 283 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
