Repository: liangyuwang/zo2
Branch: main
Commit: 4bca25c2cd69
Files: 101
Total size: 619.2 KB
Directory structure:
gitextract_z35rb3tr/
├── LICENSE
├── README.md
├── env.yml
├── example/
│ ├── demo/
│ │ └── train_zo2_with_hf_opt.py
│ └── mezo_runner/
│ ├── README.md
│ ├── metrics.py
│ ├── mezo.sh
│ ├── run.py
│ ├── tasks.py
│ ├── templates.py
│ └── utils.py
├── requirements.txt
├── script/
│ ├── add-copyright.py
│ └── clear-pycache.sh
├── setup.py
├── test/
│ ├── README.md
│ └── mezo_sgd/
│ ├── hf_gpt/
│ │ └── trainer.py
│ ├── hf_llama/
│ │ └── trainer.py
│ ├── hf_opt/
│ │ ├── record_zo2_memory.sh
│ │ ├── record_zo2_speed.sh
│ │ ├── test_acc.py
│ │ ├── test_acc_eval.sh
│ │ ├── test_acc_train.sh
│ │ ├── test_memory.py
│ │ ├── test_memory_eval.sh
│ │ ├── test_memory_train.sh
│ │ ├── test_scheduler_acc_eval.sh
│ │ ├── test_scheduler_acc_train.sh
│ │ ├── test_speed.py
│ │ ├── test_speed_eval.sh
│ │ ├── test_speed_train.sh
│ │ └── utils.py
│ ├── hf_qwen3/
│ │ ├── record_zo2_memory.sh
│ │ ├── record_zo2_speed.sh
│ │ ├── test_acc.py
│ │ ├── test_acc_eval.sh
│ │ ├── test_acc_train.sh
│ │ ├── test_memory.py
│ │ ├── test_memory_train.sh
│ │ ├── test_speed.py
│ │ ├── test_speed_train.sh
│ │ └── utils.py
│ └── nanogpt/
│ ├── record_zo2_memory.sh
│ ├── record_zo2_speed.sh
│ ├── test_acc.py
│ ├── test_acc_eval.sh
│ ├── test_acc_train.sh
│ ├── test_memory.py
│ ├── test_memory_eval.sh
│ ├── test_memory_train.sh
│ ├── test_speed.py
│ ├── test_speed_eval.sh
│ ├── test_speed_train.sh
│ └── utils.py
├── tutorial/
│ ├── README.md
│ ├── colab.ipynb
│ ├── demo.ipynb
│ ├── huggingface.ipynb
│ └── nanogpt.ipynb
└── zo2/
├── README.md
├── __init__.py
├── config/
│ ├── __init__.py
│ └── mezo_sgd.py
├── model/
│ ├── __init__.py
│ ├── base.py
│ ├── huggingface/
│ │ ├── __init__.py
│ │ ├── gpt/
│ │ │ └── mezo_sgd/
│ │ │ ├── zo.py
│ │ │ └── zo2.py
│ │ ├── llama/
│ │ │ └── mezo_sgd/
│ │ │ ├── zo.py
│ │ │ └── zo2.py
│ │ ├── opt/
│ │ │ ├── __init__.py
│ │ │ └── mezo_sgd/
│ │ │ ├── __init__.py
│ │ │ ├── utils.py
│ │ │ ├── zo.py
│ │ │ └── zo2.py
│ │ ├── qwen3/
│ │ │ ├── __init__.py
│ │ │ └── mezo_sgd/
│ │ │ ├── __init__.py
│ │ │ ├── utils.py
│ │ │ ├── zo.py
│ │ │ └── zo2.py
│ │ └── zo_init.py
│ └── nanogpt/
│ ├── __init__.py
│ ├── mezo_sgd/
│ │ ├── __init__.py
│ │ ├── zo.py
│ │ └── zo2.py
│ └── model.py
├── optimizer/
│ ├── __init__.py
│ ├── base.py
│ └── mezo_sgd/
│ ├── __init__.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── com.py
│ │ └── comm.py
│ ├── zo.py
│ └── zo2.py
├── trainer/
│ ├── __init__.py
│ ├── hf_transformers/
│ │ ├── __init__.py
│ │ └── trainer.py
│ └── hf_trl/
│ ├── __init__.py
│ └── sft_trainer.py
└── utils/
├── __init__.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# ZO2 (Zeroth-Order Offloading): Full Parameter Fine-Tuning 175B LLMs with 18GB GPU Memory
[arXiv](https://arxiv.org/abs/2503.12668)
[License](https://github.com/liangyuwang/zo2/blob/main/LICENSE)
[GitDiagram](https://gitdiagram.com/liangyuwang/zo2)
[DeepWiki](https://deepwiki.com/liangyuwang/zo2)
<!-- <a target="_blank" href="https://colab.research.google.com/github/liangyuwang/zo2/blob/main/tutorial/colab.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a> -->
👋 Welcome! **ZO2** is a framework for fine-tuning large language models (LLMs) that combines **zeroth-order (ZO)** optimization with **offloading**. It is tailored to setups with limited GPU memory (e.g., fine-tuning **[OPT-175B](https://arxiv.org/abs/2205.01068)** with just **18GB of GPU memory**), making it possible to fine-tune models that were previously out of reach on such hardware.
- The table below displays the GPU memory usage for various OPT model sizes when fine-tuned using the ZO2 framework:
| OPT Models | 1.3B | 2.7B | 6.7B | 13B | 30B | 66B | 175B |
| :-----------------------: | :------: | :------: | :------: | :------: | :------: | :-------: | :-----------------: |
| **GPU memory (GB)** | `3.75` | `4.14` | `4.99` | `6.18` | `8.86` | `12.07` | **`18.04`** |
- [Install](#️installation) the package and execute the following test to see the memory usage:
```shell
bash test/mezo_sgd/hf_opt/record_zo2_memory.sh
```
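To put the table in perspective, a back-of-envelope sketch of how large the weights alone are. It assumes 2 bytes per parameter (fp16/bf16) and uses the nominal sizes from the model names, not exact checkpoint parameter counts; ZO2's actual storage precision depends on its configuration. The GPU figures are copied from the table above.

```python
# (model, nominal parameters in billions, GPU memory from the table above in GB)
OPT_MEMORY = [
    ("OPT-1.3B", 1.3, 3.75),
    ("OPT-2.7B", 2.7, 4.14),
    ("OPT-6.7B", 6.7, 4.99),
    ("OPT-13B", 13, 6.18),
    ("OPT-30B", 30, 8.86),
    ("OPT-66B", 66, 12.07),
    ("OPT-175B", 175, 18.04),
]

def weight_gb(params_billion, bytes_per_param=2):
    """Approximate weight storage in GB (1 GB = 1e9 bytes).

    Assumption: 2 bytes per parameter, i.e. fp16/bf16 weights.
    """
    return params_billion * bytes_per_param

for name, params_b, gpu_gb in OPT_MEMORY:
    weights = weight_gb(params_b)
    print(f"{name}: ~{weights:.1f} GB of fp16 weights, "
          f"{gpu_gb:.2f} GB on the GPU ({gpu_gb / weights:.0%} of the weight size)")
```

For OPT-175B the reported GPU footprint is roughly 5% of the ~350 GB of fp16 weights, so nearly all of the weights must reside on the offloading device at any given time.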
## 📰 News
- 16/07/2025: ZO2 has been accepted to [COLM](https://colmweb.org/index.html).
- 02/05/2025: Added support for [Qwen3](https://qwenlm.github.io/blog/qwen3/). You can now fully fine-tune the [32B version](https://huggingface.co/Qwen/Qwen3-32B-FP8) with just 6GB GPU memory using ZO2. Please refer to our [example](example/mezo_runner/).
- 01/05/2025: We upgraded the environment and dependencies to align with the latest `transformers==4.51.3`.
- 06/03/2025: We have open-sourced ZO2!
## 💡 Key Features
- **Optimized ZO CPU Offloading**: ZO methods need no backward pass and no stored activations, so ZO2 can offload parameters to the CPU without the redundant transfers that first-order training incurs, significantly reducing GPU memory demands. This makes large-scale models trainable on hardware with limited GPU resources.
- **Dynamic Scheduling**: A high-performance scheduler overlaps CPU-GPU parameter transfers with computation (`computation-communication overlap`), keeping the GPU busy instead of stalling training on data movement.
- **Capability for Very Large Models**: Enables full-parameter fine-tuning of models with `175 billion parameters` (OPT-175B) on a single GPU with as little as `18GB` of memory, which was not possible with conventional methods.
- **Empirical Validation**: In our experiments, ZO2 fine-tunes very large models `without extra time cost or accuracy loss` compared with standard ZO methods.
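The **Dynamic Scheduling** idea can be illustrated with a toy double-buffering loop: while block *i* is being computed, the transfer of block *i + 1* is already in flight. This is a minimal sketch under stated assumptions, not ZO2's scheduler: `upload` and `compute` are hypothetical stand-ins that only sleep, and a single background thread plays the role of a dedicated transfer stream (on a GPU this would typically be an asynchronous copy on a separate CUDA stream). Offloading updated blocks back to the CPU is omitted.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def upload(block_id):
    """Hypothetical stand-in for copying one transformer block from CPU to GPU."""
    time.sleep(0.01)
    return block_id

def compute(block_id, x, trace):
    """Hypothetical stand-in for running one block's forward pass."""
    time.sleep(0.01)
    trace.append(block_id)
    return x + 1

def forward_with_prefetch(num_blocks, x):
    trace = []
    # One background worker acts as the transfer stream.
    with ThreadPoolExecutor(max_workers=1) as transfer:
        pending = transfer.submit(upload, 0)
        for i in range(num_blocks):
            block = pending.result()  # wait until block i is resident
            if i + 1 < num_blocks:
                # Start copying the next block before computing the current one.
                pending = transfer.submit(upload, i + 1)
            x = compute(block, x, trace)  # runs while the next copy is in flight
    return x, trace

start = time.perf_counter()
out, order = forward_with_prefetch(8, 0)
print(out, order, f"{time.perf_counter() - start:.3f}s")
```

On an otherwise idle machine the elapsed time should land well below the roughly 0.16 s that a strictly sequential upload-then-compute loop over 8 blocks would take, because only the first upload is fully exposed.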
## ⚙️ Installation
We offer two installation options, and you only need to use one of them to install ZO2:
1. To experiment with our examples, tutorials, or tests, follow these steps to set up the ZO2 environment:
```shell
git clone https://github.com/liangyuwang/zo2.git
cd zo2/
conda env create -f env.yml
conda activate zo2
```
2. If you want to use ZO2 as a package in your own code, you can install it directly in your Python environment.
Before installing the ZO2 package, ensure you have the required dependencies:
- [PyTorch](https://pytorch.org/get-started/locally/) >= 2.4.0, CUDA >= 12.1
Once the dependencies are installed, you can install the ZO2 package using pip:
```shell
pip install git+https://github.com/liangyuwang/zo2.git
```
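The PyTorch and CUDA minimums listed above can be checked from Python before installing. A minimal sketch: `parse_version` and `meets_minimum` are hypothetical helpers, not part of ZO2, and they compare only the leading numeric components (a local suffix such as `+cu121` is ignored). Comparing numerically matters, because as plain strings `"2.10.0" < "2.4.0"`.

```python
import re

def parse_version(text):
    """Leading numeric components of a version string, e.g. "2.4.0+cu121" -> (2, 4, 0)."""
    public = text.split("+", 1)[0]  # drop a local suffix such as "+cu121"
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)  # "0a0" -> 0; "dev20240901" -> stop
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)

def meets_minimum(installed, required):
    """True if `installed` is at least `required`, compared component-wise as integers."""
    have, need = parse_version(installed), parse_version(required)
    width = max(len(have), len(need))
    # Pad with zeros so that "12.1" and "12.1.0" compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need

print(meets_minimum("2.4.0+cu121", "2.4.0"))  # True
print(meets_minimum("2.3.1", "2.4.0"))        # False
print(meets_minimum("12.1", "12.1.0"))        # True
```

With PyTorch installed, pass `torch.__version__` as the first argument with `"2.4.0"` as the minimum, and, on CUDA builds, `torch.version.cuda` with `"12.1"`; `torch.version.cuda` is `None` on CPU-only builds.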
## 🛠️ Usage
We use the [OPT](https://arxiv.org/abs/2205.01068) models and [MeZO-SGD](https://arxiv.org/abs/2305.17333) as examples. For more details, see [Supported Models, ZO methods, and Tasks](#-supported-models-zo-methods-and-tasks).
### 1. Using [MeZO-Runner](example/mezo_runner/) to Evaluate Fine-tuning Tasks
Before running the following commands, make sure you have cloned the entire project. If you [installed](#️installation) ZO2 using option 2, run `git clone https://github.com/liangyuwang/zo2.git` to obtain the complete project, then enter the project folder with `cd zo2`.
```shell
cd example/mezo_runner/
export CUDA_VISIBLE_DEVICES=0
MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh
```
### 2. Fine-Tuning HF Models with ZOTrainer / ZOSFTTrainer [[Trainer](./tutorial/huggingface.ipynb)]
```python
from zo2 import ZOConfig, zo_hf_init
from zo2.trainer.hf_transformers import ZOTrainer
from transformers import TrainingArguments
# Model and optimizer init
zo_config = ZOConfig(method="mezo-sgd", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5)
with zo_hf_init(zo_config):
    from transformers import OPTForCausalLM
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
    model.zo_init(zo_config)

training_args = TrainingArguments("test-trainer")

trainer = ZOTrainer(
    model,
    args=training_args,
    train_dataset=...,  # get training dataset
    eval_dataset=...,   # get eval dataset
    data_collator=...,  # get data_collator
    tokenizer=...,      # use suitable tokenizer
    # ... other Trainer arguments
)
trainer.train()
```
### 3. Train HF Models with Custom Training Loop [[demo](./tutorial/demo.ipynb)]
```python
from zo2 import ZOConfig, zo_hf_init
# Model and optimizer init
zo_config = ZOConfig(method="mezo-sgd", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5)
with zo_hf_init(zo_config):
    from transformers import OPTForCausalLM
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
    model.zo_init(zo_config)

max_training_step = 100  # example value; set as needed

# Training loop
for i in range(max_training_step):
    # Train: in zo_train() mode, the forward call performs the ZO update and returns the loss
    training_input_ids, training_labels = ...  # get training data batch
    model.zo_train()
    loss = model(input_ids=training_input_ids, labels=training_labels)

    # Evaluate
    eval_input_ids, eval_labels = ...  # get eval data batch
    model.zo_eval()
    output = model(input_ids=eval_input_ids, labels=eval_labels)
```
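Note that the loop calls no `backward()` or `optimizer.step()`: in `zo_train()` mode the forward call itself runs the zeroth-order update (the demo script reads `model.opt.projected_grad` right after it). For intuition, here is a minimal, framework-free sketch of one MeZO-SGD step as described in the MeZO paper: two forward passes at θ + εz and θ − εz, a scalar finite-difference "projected gradient", and an update along z regenerated from the same random seed instead of being stored. It runs on a plain Python list with a toy quadratic loss; the function names are illustrative, weight decay is omitted, and it is not ZO2's implementation.

```python
import random

def perturb(theta, seed, scale):
    """In place: theta += scale * z, with z ~ N(0, I) regenerated from `seed`.

    Regenerating z from the seed (instead of storing it) avoids keeping an
    extra parameter-sized buffer in memory.
    """
    rng = random.Random(seed)
    for i in range(len(theta)):
        theta[i] += scale * rng.gauss(0.0, 1.0)

def mezo_sgd_step(theta, loss_fn, lr, eps, seed):
    """One MeZO-SGD step on a flat list of parameters; returns (loss, projected_grad)."""
    perturb(theta, seed, +eps)            # theta + eps * z
    loss_plus = loss_fn(theta)            # first forward pass
    perturb(theta, seed, -2.0 * eps)      # theta - eps * z
    loss_minus = loss_fn(theta)           # second forward pass
    perturb(theta, seed, +eps)            # back to theta (up to float rounding)
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    perturb(theta, seed, -lr * projected_grad)  # theta -= lr * projected_grad * z
    return loss_plus, projected_grad

# Toy problem: a quadratic with its minimum at `target`.
target = [1.0, -2.0, 0.5, 3.0]

def loss(theta):
    return sum((p - t) ** 2 for p, t in zip(theta, target))

theta = [0.0, 0.0, 0.0, 0.0]
initial = loss(theta)
for step in range(300):
    mezo_sgd_step(theta, loss, lr=0.01, eps=1e-3, seed=step)  # fresh seed each step
print(f"loss: {initial:.3f} -> {loss(theta):.6f}")
```

Only one random seed per step has to be remembered, which is why ZO training needs no gradient or optimizer-state buffers the size of the model.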
## ✨ Tutorial
Please refer to [tutorial](./tutorial/).
## 🤖 Supported Models, ZO methods, and Tasks
- **Models**:
* [NanoGPT](https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py) (mainly for idea evaluation)
  * [Transformers](https://github.com/huggingface/transformers): [OPT](https://arxiv.org/abs/2205.01068), [Qwen3](https://qwenlm.github.io/blog/qwen3/)
- **ZO methods**:
* [MeZO-SGD](https://arxiv.org/abs/2305.17333)
- **Tasks**: Please refer to [MeZO-Runner](example/mezo_runner/)
## 🧪 Test
Please refer to [test](./test/).
## 🧭 Roadmap
- [ ] Support more models like LLaMA and DeepSeek
- [ ] Support more ZO methods
- [ ] Support more offloading strategies (Disk offloading)
## 🚶 Contributing
Feel free to submit issues and pull requests to improve the project!
## 📲 Contact
* Liangyu Wang: liangyu.wang@kaust.edu.sa
## 📖 BibTeX
```
@article{wang2025zo2,
title={ZO2: Scalable Zeroth-Order Fine-Tuning for Extremely Large Language Models with Limited GPU Memory},
author={Wang, Liangyu and Ren, Jie and Xu, Hang and Wang, Junxiao and Xie, Huanyi and Keyes, David E and Wang, Di},
journal={arXiv preprint arXiv:2503.12668},
year={2025}
}
```
================================================
FILE: env.yml
================================================
name: zo2
channels:
- pytorch
- nvidia
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py311h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- certifi=2024.7.4=py311h06a4308_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- cuda-cudart=12.1.105=0
- cuda-cupti=12.1.105=0
- cuda-libraries=12.1.0=0
- cuda-nvrtc=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-opencl=12.6.37=0
- cuda-runtime=12.1.0=0
- cuda-version=12.6=3
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py311h06a4308_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py311hc9b5ff0_0
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py311h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.4=py311h06a4308_0
- jpeg=9e=h5eee18b_3
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libcublas=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufile=1.11.0.15=0
- libcurand=10.3.7.37=0
- libcusolver=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=2.0.0=h9bf148f_0
- libnpp=12.0.2.50=0
- libnvjitlink=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_0
- llvm-openmp=14.0.6=h9e868ea_0
- lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py311h5eee18b_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py311h5eee18b_1
- mkl_fft=1.3.8=py311h5eee18b_0
- mkl_random=1.2.4=py311hdb19cb5_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py311h06a4308_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.3=py311h06a4308_0
- numpy=1.26.4=py311h08b1b3b_0
- numpy-base=1.26.4=py311hf175353_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.14=h5eee18b_0
- pillow=10.4.0=py311h5eee18b_0
- pip=24.0=py311h06a4308_0
- pysocks=1.7.1=py311h06a4308_0
- python=3.11.9=h955ad1f_0
- pytorch=2.4.0=py3.11_cuda12.1_cudnn9.1.0_0
- pytorch-cuda=12.1=ha16c6d3_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py311h5eee18b_0
- readline=8.2=h5eee18b_0
- requests=2.32.3=py311h06a4308_0
- setuptools=72.1.0=py311h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- sympy=1.12=py311h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- torchaudio=2.4.0=py311_cu121
- torchtriton=3.0.0=py311
- torchvision=0.19.0=py311_cu121
- typing_extensions=4.11.0=py311h06a4308_0
- urllib3=2.2.2=py311h06a4308_0
- wheel=0.43.0=py311h06a4308_0
- xz=5.4.6=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.5=hc292b87_2
- pip:
    - accelerate==1.6.0
    - aiohappyeyeballs==2.3.5
    - aiohttp==3.10.3
    - aiosignal==1.3.1
    - attrs==24.2.0
    - datasets==3.5.1
    - dill==0.3.8
    - frozenlist==1.4.1
    - fsspec==2024.5.0
    - huggingface-hub==0.30.2
    - joblib==1.4.2
    - markdown-it-py==3.0.0
    - mdurl==0.1.2
    - multidict==6.0.5
    - multiprocess==0.70.16
    - nvidia-ml-py==12.570.86
    - opt-einsum==3.3.0
    - packaging==24.1
    - pandas==2.2.2
    - psutil==6.0.0
    - pyarrow==17.0.0
    - pyarrow-hotfix==0.6
    - pygments==2.19.1
    - python-dateutil==2.9.0.post0
    - pytz==2024.1
    - regex==2024.7.24
    - rich==14.0.0
    - safetensors==0.5.3
    - scikit-learn==1.5.1
    - scipy==1.14.0
    - six==1.16.0
    - threadpoolctl==3.5.0
    - tokenizers==0.21.1
    - tqdm==4.66.5
    - transformers==4.51.3
    - trl==0.17.0
    - tzdata==2024.1
    - xxhash==3.4.1
    - yarl==1.9.4
================================================
FILE: example/demo/train_zo2_with_hf_opt.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import sys
sys.path.append("../zo2")
import argparse
from tqdm.auto import tqdm
import torch
from transformers import AutoTokenizer
from zo2 import (
    ZOConfig,
    zo_hf_init,
)
from zo2.utils.utils import seed_everything
# Hyper
args = argparse.ArgumentParser()
args.add_argument("--zo_method", type=str, default="zo2")
args.add_argument("--eval", action="store_true")
args.add_argument("--model_name", type=str, default="facebook/opt-2.7b")
args.add_argument("--verbose", action="store_true")
args.add_argument("--max_steps", type=int, default=100)
args.add_argument("--lr", type=float, default=1e-5)
args.add_argument("--weight_decay", type=float, default=1e-1)
args.add_argument("--zo_eps", type=float, default=1e-3)
args.add_argument("--seed", type=int, default=42)
args.add_argument("--offloading_device", type=str, default="cpu")
args.add_argument("--working_device", type=str, default="cuda:0")
# For inference
args.add_argument("--use_cache", action="store_true")
args.add_argument("--max_new_tokens", type=int, default=50)
args.add_argument("--temperature", type=float, default=1.0)
args = args.parse_args()
seed_everything(args.seed)
# ZO steps
zo_config = ZOConfig(
method="mezo-sgd",
zo2=args.zo_method=="zo2",
lr=args.lr,
weight_decay=args.weight_decay,
eps=args.zo_eps,
offloading_device=args.offloading_device,
working_device=args.working_device,
)
# Load ZO model
with zo_hf_init(zo_config):
    from transformers import OPTForCausalLM
    model = OPTForCausalLM.from_pretrained(args.model_name)
    model.zo_init(zo_config)
if args.zo_method != "zo2":
    # Without offloading (plain ZO), move the whole model to the working device.
    model = model.to(args.working_device)
print(f"Check if zo2 init correctly: {hasattr(model, 'zo_training')}")
# Prepare some data
dataset = """
What is ZO2?
ZO2 is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using zeroth-order (ZO) optimization techniques and advanced offloading technologies.
This framework is particularly tailored for setups with limited GPU memory, enabling the fine-tuning of models that were previously unmanageable due to hardware constraints.
As the scale of Large Language Models (LLMs) continues to grow, reaching parameter counts in the hundreds of billions, managing GPU memory resources effectively becomes crucial.
Efficient GPU memory management is crucial not only because it directly influences model performance and training speed, but also because GPU memory is both expensive and limited in quantity.
However, this creates a significant challenge in handling ever-larger models within the physical constraints of current hardware technologies.
CPU offloading has become a crucial technique for overcoming this challenge.
It involves transferring computations and data from the GPU to the CPU, specifically targeting data or parameters that are less frequently accessed.
By offloading these inactive tensors of the neural network, CPU offloading effectively alleviates the memory and computational pressures on GPUs.
While CPU offloading has been commonly applied in inference to manage memory-intensive tasks, its application in training, especially fine-tuning, remains less explored.
Recently, some works have tried to introduce CPU offloading into LLM training.
However, they are typically constrained by the capabilities of first-order optimizers such as SGD and Adaptive Moment Estimation (AdamW), and limited GPU memory, restricting large-scale model scalability on single GPU systems.
Using first-order optimizers introduces inefficiencies in CPU offloading: Multiple communication operations during the training of LLMs necessitate offloading the same data twice—once for each pass.
This redundancy not only doubles the communication volume between the CPU and GPU but also introduces significant latency due to repetitive data transfers.
Furthermore, both parameters and activations are required in the backward pass to complete gradient computations.
This means that parameters and activation values must be offloaded during each forward pass and re-uploaded to the GPU for the backward pass, increasing the volume of data transferred, which severely impacts training throughput.
On the other hand, zeroth-order (ZO) methods offer a novel approach to fine-tuning LLMs.
These methods utilize dual forward passes to estimate parameter gradients and subsequently update parameters.
This approach eliminates the traditional reliance on backward passes, thereby streamlining the training process by significantly reducing the number of computational steps required.
Based on these observations, we conjecture that ZO's architecture is particularly well-suited for CPU offloading strategies.
By eliminating backward passes and the need to store activation values, it can significantly reduce GPU memory demands through efficient parameter offloading.
However, despite these advantages, ZO training via CPU offloading introduces new challenges, particularly in the realm of CPU-to-GPU communication.
Transferring parameters between the CPU and GPU, which is crucial for maintaining gradient computation and model updates, becomes a critical bottleneck.
Although ZO methods inherently extend computation times because of the dual forward passes, potentially allowing for better overlap between computation and communication, there remain significant inefficiencies.
The necessity to upload parameters to the GPU for upcoming computations introduces a large volume of communications. To tackle the inefficiencies highlighted, we introduce ZO2, a novel framework specifically designed for ZO fine-tuning in LLMs with CPU offloading.
This framework utilizes the unique dual forward pass architecture of ZO methods to optimize interactions between CPU and GPU, significantly enhancing both computational and communication efficiency.
By building a high-performance dynamic scheduler, ZO2 achieves substantial overlaps in communication and computation.
These innovations make it feasible to fine-tune extremely large models, such as the OPT-175B, with over 175 billion parameters, on a single GPU equipped with just 18GB of memory usage—a capability previously unattainable with conventional methods.
Additionally, our efficient framework operates without any extra time cost and decreases in accuracy compared to standard ZO methodologies."""
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
data_batch = tokenizer(dataset, add_special_tokens=True, return_tensors='pt').input_ids.to(args.working_device)
T = min(data_batch.shape[1] - 1, model.config.max_position_embeddings)
print(f"Fine-tuning model {args.model_name} with {T} tokens dataset: \n{dataset}")
# Training loop
for i in tqdm(range(args.max_steps)):
    # Train: in zo_train() mode, the forward call also applies the ZO update.
    model.zo_train()
    loss = model(input_ids=data_batch, labels=data_batch)
    # Eval
    if args.eval:
        if i == 0:
            tqdm.write("Warning: ZO2 does not optimize evaluation, so it may be very slow.")
        model.zo_eval()
        output = model(input_ids=data_batch, labels=data_batch)
        res = "Iteration {}, train loss: {}, projected grad: {}, eval loss: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad, output["loss"]))
    else:
        res = "Iteration {}, train loss: {}, projected grad: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad))
# inference
print("Doing inference...")
print("Warning: please notice that ZO2 does not optimize the inference, so it may be very slow.")
model.zo_eval()
prompt = "What is ZO2 and how ZO2 enhance the fine-tuning of large language models?"
inputs = tokenizer(prompt, return_tensors='pt').to(args.working_device)
inputs = {"input_ids": inputs.input_ids}
for _ in tqdm(range(args.max_new_tokens)):
outputs = model(**inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
# Note: this demo treats temperature == 1.0 as greedy (argmax) decoding; other values sample
if args.temperature == 1.0:
next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
else:
scaled_logits = next_token_logits / args.temperature
probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
inputs = torch.cat([inputs["input_ids"], next_token], dim=-1)
generated_text = tokenizer.decode(inputs[0])
inputs = {"input_ids": inputs}
print(f"Question: {prompt}")
# Skip the decoded prompt; the extra 4 characters account for OPT's "</s>" BOS token
print(f"Response: {generated_text[len(prompt)+4:]}...")
================================================
FILE: example/mezo_runner/README.md
================================================
# Example: Apply MeZO on LLMs
Modified from [MeZO](https://github.com/princeton-nlp/MeZO/blob/main/large_models/README.md)
## Usage
Use `run.py` for all functions (zero-shot/MeZO):
```bash
python run.py {ARGUMENTS}
```
Please read [run.py](./run.py) for a complete list of arguments. We introduce some of the most important ones below.
* `--num_train`: Number of training examples.
* `--num_dev`: Number of validation examples.
* `--num_eval`: Number of evaluation (test) examples.
* `--model_name`: HuggingFace model name or path.
* `--task_name`: Task name.
* `--trainer`: can be `none` (zero-shot) or `zo` (MeZO).
* `--train_as_classification`: turn this on for classification tasks (cross-entropy over the likelihood of each class's label words). Otherwise, the loss is LM-style teacher forcing.
* `--zo_eps`: MeZO hyperparameter epsilon.
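* `--zo_method`: zeroth-order method to use (default `mezo-sgd`).
* `--zo_mode`: can be `zo` (everything on the working device) or `zo2` (CPU offloading; default).
* `--offloading_device`: device that stores offloaded parameters (default `cpu`).
* `--working_device`: main device where computation runs (default `cuda:0`).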
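With `--train_as_classification`, every candidate's label words are scored by the model and the loss is a cross-entropy over those candidate scores. The sketch below is a minimal pure-Python illustration, assuming each candidate is scored by the mean log-likelihood of its option tokens (an assumption modeled on the original MeZO code). `classification_loss` is a hypothetical helper, not part of this repository; `run.py` registers its own loss on the model through `zo_custom_train_loss_fn`.

```python
import math
from typing import List, Sequence


def classification_loss(option_log_probs: List[Sequence[float]], label: int) -> float:
    """Cross-entropy over candidates, given per-token log-probs of each candidate's option tokens.

    Hypothetical helper for illustration only.
    """
    # Assumption: score each candidate by its mean option-token log-likelihood.
    scores = [sum(lp) / len(lp) for lp in option_log_probs]
    # -log softmax(scores)[label], computed with log-sum-exp for numerical stability.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[label]


# Two candidates (e.g. an SST2 negative/positive pair); the correct one (index 1) is more likely.
loss = classification_loss([[-2.3, -1.1], [-0.4]], label=1)
print(round(loss, 4))
```

Because candidates are compared through a softmax, only the relative scores matter, which is why this objective suits tasks with a fixed set of label words.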
Example:
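1. MeZO (full-parameter fine-tuning; `MODE=ft` is the only mode that currently works, since `--prefix_tuning` and `--lora` raise `NotImplementedError` in `run.py`)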
```bash
# You can adjust the following model size and other hyperparameters.
# OPT-2.7B
MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh
# Qwen3-1.7B
MODEL=Qwen/Qwen3-1.7B TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh
```
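In the example above, `mezo.sh` forwards `LR` and `EPS` to `--learning_rate` and `--zo_eps`. To show what these two knobs control, here is a minimal pure-Python sketch of one MeZO-SGD step on a flat list of floats. The loss is evaluated at theta + eps*z and theta - eps*z, the projected gradient is their difference divided by 2*eps, and the parameters are then moved along z. The noise z is never stored; it is regenerated from a seed each time it is needed. The names `perturb`, `mezo_sgd_step`, and `quadratic` are hypothetical. ZO2 applies the same idea to model tensors (and, in `zo2` mode, to offloaded blocks), so treat this as an illustration rather than the library's implementation.

```python
import random
from typing import Callable, List


def perturb(params: List[float], seed: int, scale: float) -> None:
    """Add scale * z to params in place, where z ~ N(0, I) is regenerated from `seed`."""
    rng = random.Random(seed)  # fresh generator, so the same seed always yields the same z
    for i in range(len(params)):
        params[i] += scale * rng.gauss(0.0, 1.0)


def mezo_sgd_step(params: List[float], loss_fn: Callable[[List[float]], float],
                  lr: float, eps: float, seed: int) -> float:
    """One hypothetical MeZO-SGD step; returns the projected gradient."""
    perturb(params, seed, +eps)        # theta + eps * z
    loss_plus = loss_fn(params)
    perturb(params, seed, -2 * eps)    # theta - eps * z
    loss_minus = loss_fn(params)
    perturb(params, seed, +eps)        # restore theta (up to float rounding)
    projected_grad = (loss_plus - loss_minus) / (2 * eps)
    perturb(params, seed, -lr * projected_grad)  # theta <- theta - lr * g * z
    return projected_grad


def quadratic(p: List[float]) -> float:
    return sum(x * x for x in p)


params = [1.0, -2.0, 0.5]
start = quadratic(params)
for step in range(200):
    mezo_sgd_step(params, quadratic, lr=0.05, eps=1e-3, seed=step)
print(f"loss: {start:.3f} -> {quadratic(params):.6f}")
```

Only two forward passes and a seed are needed per step, which is why MeZO (and ZO2 on top of it) avoids storing activations or gradients for backpropagation.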
## Supported Tasks (See [tasks.py](./tasks.py))
- **SST2**
- **Copa**
- **BoolQ**
- **MultiRC**
- **CB**
- **WIC**
- **WSC**
- **ReCoRD**
- **RTE**
- **SQuAD**
- **DROP**
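Evaluation in `run.py` depends on the task type. Classification and multiple-choice tasks pick the candidate with the highest length-normalized log-probability (the mean over option tokens, unless `--sfc`/`--icl_sfc` calibration is enabled), while QA-style tasks can be scored with the SQuAD-style token F1 from `metrics.py`. The stdlib-only sketch below condenses both rules. `pick_candidate` and `token_f1` are hypothetical names, and the sketch omits the exact-match special case that `metrics.py` applies when the gold answer is `CANNOTANSWER` or `no answer`.

```python
import re
import string
from collections import Counter
from typing import List, Sequence

_PUNCT = set(string.punctuation)


def pick_candidate(option_log_probs: List[Sequence[float]]) -> int:
    """Index of the candidate with the highest mean per-token log-probability (first on ties)."""
    scores = [sum(lp) / len(lp) for lp in option_log_probs]
    return max(range(len(scores)), key=scores.__getitem__)


def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace (same order as metrics.py)."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def token_f1(pred: str, golds: List[str]) -> float:
    """Best token-overlap F1 between the prediction and any gold answer."""
    pred_tokens = normalize_answer(pred).split()
    best = 0.0
    for gold in golds:
        gold_tokens = normalize_answer(gold).split()
        num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
        if num_same == 0:
            continue
        precision = num_same / len(pred_tokens)
        recall = num_same / len(gold_tokens)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best


print(pick_candidate([[-0.5, -0.7], [-0.2]]))
print(token_f1("The cat sat", ["a cat sat down"]))
```

Averaging log-probabilities instead of summing them keeps longer candidates from being penalized simply for having more tokens.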
================================================
FILE: example/mezo_runner/metrics.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import numpy as np
import collections
import re
import string
from collections import Counter
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def calculate_metric(predictions, metric_name):
if metric_name == "accuracy":
if isinstance(predictions[0].correct_candidate, list):
return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions])
else:
return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions])
elif metric_name == "em":
# For question answering
return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions])
elif metric_name == "f1":
# For question answering
f1 = []
for pred in predictions:
all_f1s = []
if pred.correct_candidate[0] == "CANNOTANSWER" or pred.correct_candidate[0] == "no answer":
f1.append(int(normalize_answer(pred.correct_candidate[0]) == normalize_answer(pred.predicted_candidate)))
else:
for ans in pred.correct_candidate:
prediction_tokens = normalize_answer(pred.predicted_candidate).split()
ground_truth_tokens = normalize_answer(ans).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
all_f1s.append(0)
else:
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
all_f1s.append((2 * precision * recall) / (precision + recall))
f1.append(max(all_f1s))
return np.mean(f1)
def f1(pred, gold):
"""
This separate F1 function is used as the non-differentiable objective for SQuAD (`--non_diff`)
"""
if gold[0] == "CANNOTANSWER" or gold[0] == "no answer":
return int(normalize_answer(gold[0]) == normalize_answer(pred))
else:
all_f1s = []
for ans in gold:
prediction_tokens = normalize_answer(pred).split()
ground_truth_tokens = normalize_answer(ans).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
all_f1s.append(0)
else:
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
all_f1s.append((2 * precision * recall) / (precision + recall))
return np.max(all_f1s)
================================================
FILE: example/mezo_runner/mezo.sh
================================================
MODEL=${MODEL:-facebook/opt-1.3b}
MODEL_NAME=(${MODEL//\// })
MODEL_NAME="${MODEL_NAME[-1]}"
BS=${BS:-16}
LR=${LR:-1e-5}
EPS=${EPS:-1e-3}
SEED=${SEED:-0}
TRAIN=${TRAIN:-1000}
DEV=${DEV:-500}
EVAL=${EVAL:-1000}
STEPS=${STEPS:-20000}
EVAL_STEPS=${EVAL_STEPS:-4000}
MODE=${MODE:-ft}
EXTRA_ARGS=""
if [ "$MODE" == "prefix" ]; then
EXTRA_ARGS="--prefix_tuning --num_prefix 5 --no_reparam --prefix_init_by_real_act"
elif [ "$MODE" == "lora" ]; then
EXTRA_ARGS="--lora"
fi
TAG=mezo-$MODE-$STEPS-$BS-$LR-$EPS-$SEED
TASK_ARGS=""
case $TASK in
# For Copa, ReCoRD, SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True
CB) # It has <1000 training examples. Only use 100 for dev
DEV=100
;;
Copa) # It has <1000 training examples. Only use 100 for dev
DEV=100
TASK_ARGS="--train_as_classification False"
;;
ReCoRD)
TASK_ARGS="--train_as_classification False"
;;
DROP)
TASK_ARGS="--train_as_classification False"
;;
SQuAD)
TASK_ARGS="--train_as_classification False"
;;
esac
echo $TAG
echo "BS: $BS"
echo "LR: $LR"
echo "EPS: $EPS"
echo "SEED: $SEED"
echo "TRAIN/EVAL STEPS: $STEPS/$EVAL_STEPS"
echo "MODE: $MODE"
echo "Extra args: $EXTRA_ARGS $TASK_ARGS"
python run.py \
--model_name $MODEL \
--task_name $TASK \
--output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \
--max_steps $STEPS \
--trainer zo --load_float16 \
--learning_rate $LR --zo_eps $EPS --per_device_train_batch_size $BS --lr_scheduler_type "constant" \
--load_best_model_at_end --eval_strategy steps --save_strategy steps --save_total_limit 1 \
--eval_steps $EVAL_STEPS --save_steps $EVAL_STEPS \
--train_as_classification \
$EXTRA_ARGS \
$TASK_ARGS \
"$@"
================================================
FILE: example/mezo_runner/run.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
"""
Modified from https://github.com/princeton-nlp/MeZO/blob/main/large_models/run.py
"""
import sys
sys.path.append("../../../zo2")
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import argparse
import os # used for checkpoint detection in Framework.train
import time
import tasks
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, HfArgumentParser, TrainingArguments, DataCollatorWithPadding, DataCollatorForTokenClassification
from typing import Union, Optional
import torch
from torch.nn.parameter import Parameter
import numpy as np
from dataclasses import dataclass, is_dataclass, asdict
from tqdm import tqdm
from tasks import get_task
import json
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from metrics import calculate_metric
from utils import *
import random
from zo2.trainer.hf_transformers.trainer import ZOTrainer
from zo2 import zo_hf_init, ZOConfig
@dataclass
class OurArguments(TrainingArguments):
# dataset and sampling strategy
task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP
# Number of examples
num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples
num_dev: int = None # (only enabled with training) number of development samples
num_eval: int = None # number of evaluation samples
num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample
train_set_seed: int = None # designated seed to sample training samples/demos
result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config
# Model loading
model_name: str = "facebook/opt-125m" # HuggingFace model name
load_float16: bool = False # load model parameters as float16
load_bfloat16: bool = False # load model parameters as bfloat16
load_int8: bool = False # load model parameters as int8
max_length: int = 2048 # max length the model can take
no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP
# Calibration
sfc: bool = False # whether to use SFC calibration
icl_sfc: bool = False # whether to use SFC calibration for ICL samples
# Training
trainer: str = "none"
## options
## - none: no training -- for zero-shot or in-context learning (ICL)
## - regular: regular huggingface trainer -- for fine-tuning
## - zo: zeroth-order (MeZO) training
only_train_option: bool = True # whether to only train the option part of the input
train_as_classification: bool = False # take the log likelihood of all options and train as classification
# MeZO
zo_eps: float = 1e-3 # eps in MeZO
# Prefix tuning
prefix_tuning: bool = False # whether to use prefix tuning
num_prefix: int = 5 # number of prefixes to use
no_reparam: bool = True # do not use reparameterization trick
prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words
# LoRA
lora: bool = False # whether to use LoRA
lora_alpha: int = 16 # alpha in LoRA
lora_r: int = 8 # r in LoRA
# Generation
sampling: bool = False # whether to use sampling
temperature: float = 1.0 # temperature for generation
num_beams: int = 1 # number of beams for generation
top_k: int = None # top-k for generation
top_p: float = 0.95 # top-p for generation
max_new_tokens: int = 50 # max number of new tokens to generate
eos_token: str = "\n" # end of sentence token
# Saving
save_model: bool = False # whether to save the model
no_eval: bool = False # whether to skip evaluation
tag: str = "" # saving tag
# Linear probing
linear_probing: bool = False # whether to do linear probing
lp_early_stopping: bool = False # whether to do early stopping in linear probing
head_tuning: bool = False # head tuning: only tune the LM head
# Untie emb/lm_head weights
untie_emb: bool = False # untie the embeddings and LM head
# Display
verbose: bool = False # verbose output
# Non-diff objective
non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now)
# Auto saving when interrupted
save_on_interrupt: bool = False # save model when interrupted (useful for long training)
# ZO2 added -> ZO2 configs
zo_method: str = "mezo-sgd"
zo_mode: str = "zo2"
offloading_device: str = "cpu"
working_device: str = "cuda:0"
def parse_args():
parser = argparse.ArgumentParser()
parser = HfArgumentParser(OurArguments)
args = parser.parse_args_into_dataclasses()[0]
print(args)
return args
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
class Framework:
def __init__(self, args, task):
self.args = args
self.task = task
self.model, self.tokenizer = self.load_model()
def load_model(self):
"""
Load HuggingFace models
"""
with count_time("Loading model with %s" % ("FP16" if self.args.load_float16 else "BF16" if self.args.load_bfloat16 else "FP32")):
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
config = AutoConfig.from_pretrained(self.args.model_name)
if self.args.untie_emb:
# Untie embeddings/LM head
logger.warning("Untie embeddings and LM head")
config.tie_word_embeddings = False
# if self.args.head_tuning:
# # Head tuning
# from ht_opt import OPTForCausalLM
# model = OPTForCausalLM.from_pretrained(
# self.args.model_name,
# config=config,
# )
# elif self.args.no_auto_device:
# # No auto device (use for FSDP)
# model = AutoModelForCausalLM.from_pretrained(
# self.args.model_name,
# config=config,
# )
# else:
# # Auto device loading
# torch_dtype = torch.float32
# if self.args.load_float16:
# torch_dtype = torch.float16
# elif self.args.load_bfloat16:
# torch_dtype = torch.bfloat16
# model = AutoModelForCausalLM.from_pretrained(
# self.args.model_name,
# config=config,
# device_map='auto',
# torch_dtype=torch_dtype,
# max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},
# load_in_8bit=self.args.load_int8,
# )
# ZO2 added -> init ZO2 model
torch_dtype = torch.float32
if self.args.load_float16:
torch_dtype = torch.float16
elif self.args.load_bfloat16:
torch_dtype = torch.bfloat16
# Set up ZO configuration
self.zo_config = ZOConfig(
method=self.args.zo_method,
zo2=(self.args.zo_mode == "zo2"),
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay,
eps=self.args.zo_eps,
offloading_device=self.args.offloading_device,
working_device=self.args.working_device,
)
# Initialize model within zo_hf_init context
with zo_hf_init(self.zo_config):
if "opt" in self.args.model_name:
from transformers import OPTForCausalLM
model = OPTForCausalLM.from_pretrained(
self.args.model_name,
config=config,
torch_dtype=torch_dtype,
max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},
load_in_8bit=self.args.load_int8,
)
elif "Qwen3" in self.args.model_name:
from transformers import Qwen3ForCausalLM
model = Qwen3ForCausalLM.from_pretrained(
self.args.model_name,
config=config,
torch_dtype=torch_dtype,
max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},
load_in_8bit=self.args.load_int8,
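)
else:
raise NotImplementedError(f"run.py only loads OPT and Qwen3 models, got: {self.args.model_name}")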
model.zo_init(self.zo_config)
logger.info(f"Check if zo2 init correctly: {hasattr(model, 'zo_training')}")
# If not offloading (zo_mode == "zo"), move the whole model to the working device
if self.args.zo_mode != "zo2":
model = model.to(self.args.working_device)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, use_fast=False)
# HF tokenizer bug fix
if "opt" in self.args.model_name:
tokenizer.bos_token_id = 0
if "llama" in self.args.model_name:
# LLaMA padding token
tokenizer.pad_token_id = 0 # technically <unk>
if "Qwen3" in self.args.model_name:
# Qwen3: do not prepend a BOS token
tokenizer.add_bos_token = False
# Prefix tuning/LoRA
if self.args.prefix_tuning:
# from prefix import PrefixTuning
# PrefixTuning(model, num_prefix=self.args.num_prefix, reparam=not self.args.no_reparam, float16=self.args.load_float16, init_by_real_act=self.args.prefix_init_by_real_act)
raise NotImplementedError
if self.args.lora:
# from lora import LoRA
# LoRA(model, r=self.args.lora_r, alpha=self.args.lora_alpha, float16=self.args.load_float16)
raise NotImplementedError
if self.args.head_tuning:
# if model.config.model_type == "opt":
# head_name = "lm_head" if self.args.untie_emb else "embed_tokens"
# else:
# raise NotImplementedError
# for n, p in model.named_parameters():
# if head_name not in n:
# p.requires_grad = False
# else:
# logger.info(f"Only tuning {n}")
raise NotImplementedError
return model, tokenizer
def forward(self, input_ids, option_len=None, generation=False):
"""
Given input_ids and the length of the option, return the log-likelihood of each token in the option.
For generation tasks, return the generated text.
This function is only for inference
"""
input_ids = torch.tensor([input_ids]).to(self.model.device)
if generation:
args = self.args
# Autoregressive generation
outputs = self.model.generate(
input_ids, do_sample=args.sampling, temperature=args.temperature,
num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k, max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)),
num_return_sequences=1, eos_token_id=[self.tokenizer.encode(args.eos_token, add_special_tokens=False)[-1], self.tokenizer.eos_token_id],
)
# For generation, directly return the text output
output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip()
return output_text
else:
with torch.inference_mode():
self.model.eval()
logits = self.model(input_ids=input_ids).logits
labels = input_ids[0, 1:]
logits = logits[0, :-1]
log_probs = F.log_softmax(logits, dim=-1)
selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels]
selected_log_probs = selected_log_probs.cpu().detach()
# Only return the option (candidate) part
return selected_log_probs[-option_len:]
def one_step_pred(self, train_samples, eval_sample, verbose=False):
"""
Return the prediction on the eval sample. In ICL, use train_samples as demonstrations
"""
verbose = verbose or self.args.verbose
if verbose:
logger.info("========= Example =========")
logger.info(f"Candidate: {eval_sample.candidates}")
logger.info(f"Correct candidate: {eval_sample.correct_candidate}")
# Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options)
encoded_candidates, option_lens = encode_prompt(
self.task, self.task.get_template(), train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length,
generation=self.task.generation, max_new_tokens=self.args.max_new_tokens
)
# Calibration
if self.args.sfc or self.args.icl_sfc:
sfc_encoded_candidates, sfc_option_lens = encode_prompt(self.task, self.task.get_template(),
train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length,
sfc=self.args.sfc, icl_sfc=self.args.icl_sfc, generation=self.task.generation,
max_new_tokens=self.args.max_new_tokens
)
outputs = []
if self.task.generation:
# For generation tasks, return the autoregressively-generated text
output_text = self.forward(encoded_candidates[0], generation=True)
if verbose:
logger.info("=== Prompt ===")
logger.info(self.tokenizer.decode(encoded_candidates[0]))
logger.info(f"Output: {output_text}")
return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text)
else:
# For classification/multiple-choice, calculate the probabilities of all candidates
for candidate_id, encoded_candidate in enumerate(encoded_candidates):
selected_log_probs = self.forward(encoded_candidate, option_len=option_lens[candidate_id])
if verbose:
if candidate_id == 0:
logger.info("=== Candidate %d ===" % candidate_id)
logger.info(self.tokenizer.decode(encoded_candidate))
else:
logger.info("=== Candidate %d (without context)===" % candidate_id)
logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1])
logger.info(f"Log probabilities of the option tokens: {selected_log_probs}")
if self.args.sfc or self.args.icl_sfc:
sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id], option_len=sfc_option_lens[candidate_id])
if verbose:
logger.info("=== Candidate %d (without context) SFC ===" % candidate_id)
logger.info(self.tokenizer.decode(sfc_encoded_candidates[candidate_id]).split(self.task.train_sep)[-1])
logger.info(f"Log probabilities of the option tokens: {sfc_selected_log_probs}")
outputs.append({"log_probs": selected_log_probs, "sfc_log_probs": sfc_selected_log_probs if self.args.sfc or self.args.icl_sfc else None})
if self.args.sfc or self.args.icl_sfc:
# Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf)
# log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt)
scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs]
else:
# (Default) length-normalized log probabilities
# log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens|
scores = [x['log_probs'].mean().item() for x in outputs]
if verbose:
logger.info(f"Prediction scores: {scores}")
if isinstance(eval_sample.correct_candidate, list):
# For some datasets there are multiple correct answers
correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate]
else:
correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate)
return Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores)))
def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False):
"""
Evaluate function. If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set.
"""
if one_train_set_per_eval_sample:
logger.info(f"There are {len(eval_samples)} validation samples and one train set per eval sample")
else:
logger.info(f"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples")
# Prediction loop
predictions = []
for eval_id, eval_sample in enumerate(tqdm(eval_samples)):
predictions.append(
self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples, eval_sample, verbose=(eval_id < 3))
)
# Calculate metrics
metric_name = getattr(self.task, "metric_name", "accuracy")
metrics = {metric_name: calculate_metric(predictions, metric_name)}
return metrics
def train(self, train_samples, eval_samples):
"""
Training function
"""
# Set tokenizer to left padding (so that all the options are right aligned)
self.tokenizer.padding_side = "left"
class HFDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def _convert(samples):
"""
Convert samples to HF-compatible dataset
"""
data = []
for sample in samples:
encoded_candidates, option_lens = encode_prompt(
self.task, self.task.get_template(), [], sample, self.tokenizer,
max_length=self.args.max_length, generation=self.task.generation, generation_with_gold=True,
max_new_tokens=self.args.max_new_tokens
)
if self.task.generation:
correct_candidate_id = 0
elif isinstance(sample.correct_candidate, list):
correct_candidate_id = sample.candidates.index(sample.correct_candidate[0])
else:
correct_candidate_id = sample.candidates.index(sample.correct_candidate)
if self.args.non_diff:
# For non-differentiable objective, there is no teacher forcing thus the
# current answer part is removed
encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][:-option_lens[correct_candidate_id]]
if self.args.train_as_classification:
# For classification, we provide the label as the correct candidate id
data.append([{"input_ids": encoded_candidates[_i], "labels": correct_candidate_id, "option_len": option_lens[_i], "num_options": len(sample.candidates)} for _i in range(len(encoded_candidates))])
elif self.args.only_train_option:
# Otherwise, it is just LM-style teacher forcing
if self.args.non_diff:
# For non-differentiable objective, we need to provide the gold answer to calculate F1/acc
data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id], "gold": sample.correct_candidate})
else:
data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id]})
else:
data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id]})
return data
with count_time("Tokenizing training samples"):
train_dataset = HFDataset(_convert(train_samples))
eval_dataset = HFDataset(_convert(eval_samples))
if self.args.only_train_option and not self.args.non_diff:
# # If --only_train_option and not with a non-differentiable objective, we wrap the forward function
# self.model.original_forward = self.model.forward
# self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model))
# ZO2 added -> register custom loss functions
self.model.zo_custom_train_loss_fn = custom_loss_fn_with_option_len
self.model.zo_custom_eval_loss_fn = custom_loss_fn_with_option_len
if self.args.non_diff:
collator = NondiffCollator
else:
collator = DataCollatorForTokenClassification
# ZO2 added ->
trainer = ZOTrainer(
model=self.model,
args=self.args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=self.tokenizer,
data_collator=DataCollatorWithPaddingAndNesting(self.tokenizer, pad_to_multiple_of=8) if self.args.train_as_classification else collator(self.tokenizer, pad_to_multiple_of=8),
)
if self.args.save_on_interrupt:
trainer.add_callback(SIGUSR1Callback())
# Resume training from a last checkpoint
last_checkpoint = None
from transformers.trainer_utils import get_last_checkpoint
if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(self.args.output_dir)
if last_checkpoint is not None and self.args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
if self.args.resume_from_checkpoint is not None:
last_checkpoint = self.args.resume_from_checkpoint
trainer.train(resume_from_checkpoint=last_checkpoint)
# Explicitly save the model
if self.args.save_model:
logger.warning("Saving model...")
trainer.save_model()
# FSDP compatibility
self.model = trainer.model
# Reset the forward function for evaluation
if self.args.only_train_option and not self.args.non_diff:
# if type(self.model) == FSDP:
# logger.info("This is an FSDP model now. Be careful when assigning back the original forward function")
# self.model._fsdp_wrapped_module.forward = self.model._fsdp_wrapped_module.original_forward
# else:
# self.model.forward = self.model.original_forward
# ZO2 added -> remove the custom loss functions for evaluation
self.model.zo_custom_train_loss_fn = None
self.model.zo_custom_eval_loss_fn = None
def result_file_tag(args):
"""
Get the result file tag
"""
save_model_name = args.model_name.split("/")[-1]
sfc_tag = "-sfc" if args.sfc else ""
icl_sfc_tag = "-icl_sfc" if args.icl_sfc else ""
sample_eval_tag = "-sampleeval%d" % args.num_eval if args.num_eval is not None else ""
sample_train_tag = "-ntrain%d" % args.num_train if args.num_train > 0 else ""
sample_dev_tag = "-ndev%d" % args.num_dev if args.num_dev is not None else ""
customized_tag = f"-{args.tag}" if len(args.tag) > 0 else ""
return f"{args.task_name}-{save_model_name}" + sfc_tag + icl_sfc_tag + sample_eval_tag + sample_train_tag + sample_dev_tag + customized_tag
def main():
args = parse_args()
set_seed(args.seed)
task = get_task(args.task_name)
train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval, num_train_sets=args.num_train_sets, seed=args.train_set_seed)
# Initialize trainer and load model
framework = Framework(args, task)
if args.train_set_seed is not None or args.num_train_sets is not None:
# Eval samples share one (or multiple) training set(s)
for train_set_id, train_samples in enumerate(train_sets):
train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed
# Sample eval samples
if args.num_eval is not None:
eval_samples = task.sample_subset(data_split="valid", seed=train_set_seed, num=args.num_eval)
else:
eval_samples = task.valid_samples
if args.trainer != "none":
if args.num_dev is not None:
# Dev samples
dev_samples = train_samples[-args.num_dev:]
train_samples = train_samples[:-args.num_dev]
else:
dev_samples = None
# Training
framework.train(train_samples, dev_samples if dev_samples is not None else eval_samples)
if not args.no_eval:
metrics = framework.evaluate([], eval_samples) # No in-context learning if there is training
if dev_samples is not None:
dev_metrics = framework.evaluate([], dev_samples)
for m in dev_metrics:
metrics["dev_" + m] = dev_metrics[m]
else:
assert args.num_dev is None
# Zero-shot / in-context learning
metrics = framework.evaluate(train_samples, eval_samples)
if not args.no_eval:
logger.info("===== Train set %d =====" % train_set_seed)
logger.info(metrics)
if args.local_rank <= 0:
write_metrics_to_file(metrics, "result/" + result_file_tag(args) + f"-trainset{train_set_id}.json" if args.result_file is None else args.result_file)
else:
# For each eval sample, there is a training set. no training is allowed
# This is for in-context learning (ICL)
assert args.trainer == "none"
if args.num_eval is not None:
eval_samples = task.sample_subset(data_split="valid", seed=0, num=args.num_eval)
else:
eval_samples = task.valid_samples
metrics = framework.evaluate(train_sets, eval_samples, one_train_set_per_eval_sample=True)
logger.info(metrics)
if args.local_rank <= 0:
write_metrics_to_file(metrics, "result/" + result_file_tag(args) + "-onetrainpereval.json" if args.result_file is None else args.result_file)
if __name__ == "__main__":
main()
================================================
FILE: example/mezo_runner/tasks.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
"""
Copied from https://github.com/princeton-nlp/MeZO/blob/main/large_models/tasks.py
"""
from templates import *
from utils import temp_seed
import json
import os
from datasets import load_dataset
from dataclasses import dataclass
from typing import List, Union
import string
import random
import datasets
import sys
import numpy as np
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def get_task(task_name):
aa = task_name.split("__")
if len(aa) == 2:
task_group, subtask = aa
else:
task_group = aa[0]
subtask = None
class_ = getattr(sys.modules[__name__], f"{task_group}Dataset")
instance = class_(subtask)
return instance
@dataclass
class Sample:
id: int = None
data: dict = None
correct_candidate: Union[str, List[str]] = None
candidates: List[str] = None
class Dataset:
mixed_set = False
train_sep = "\n\n"
generation = False # whether this is a generation task
def __init__(self, subtask=None, **kwargs) -> None:
self.subtask = subtask
def get_task_name(self):
return self.subtask
def load_dataset(self, path, **kwargs):
raise NotImplementedError
def get_template(self, template_version=0):
templates = {0: Template}
return templates[template_version]
def build_sample(self, example):
return
def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None, num_train_sets=None, seed=None):
if seed is not None:
# one train/demo set using the designated seed
seeds = [seed]
elif num_train_sets is not None:
# num_train_sets train/demo sets
seeds = list(range(num_train_sets))
else:
# one train/demo set per evaluation sample
assert num_dev is None # not supported
len_valid_samples = len(self.samples["valid"]) if num_eval is None else num_eval
with temp_seed(0):
seeds = np.random.randint(0, 10000, len_valid_samples)
train_samples = []
for i, set_seed in enumerate(seeds):
if self.mixed_set:
raise NotImplementedError
train_samples.append(self.sample_subset(data_split="valid", seed=set_seed, num=num_train, exclude=i))
else:
if num_dev is not None:
train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train+num_dev)) # dev set is included at the end of train set
if num_train + num_dev > len(self.samples["train"]):
logger.warning("num_train + num_dev > available training examples")
else:
train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train))
if num_dev is not None:
logger.info(f"Sample train set {len(train_samples[-1])}/{len(self.samples['train'])}")
logger.info(f"... including dev set {num_dev} samples")
return train_samples
def sample_subset(self, data_split="train", seed=0, num=100, exclude=None):
with temp_seed(seed):
samples = self.samples[data_split]
lens = len(samples)
index = np.random.permutation(lens).tolist()[:num if exclude is None else num+1]
if exclude is not None and exclude in index:
index.remove(exclude)
else:
index = index[:num]
return [samples[i] for i in index]
@property
def valid_samples(self):
return self.samples["valid"]
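When `exclude` is set, `sample_subset` draws one more index than it needs. Removing the held-out evaluation index then still leaves `num` demonstrations. If the excluded index was not drawn, the list is cut back to `num` instead. The sketch below shows this logic. It uses the standard-library `random.Random` in place of NumPy's global RNG, which is an assumption: the original seeds NumPy through `temp_seed`, so the sampled indices will differ.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def sample_subset_sketch(samples: Sequence[T], seed: int = 0, num: int = 100,
                         exclude: Optional[int] = None) -> List[T]:
    rng = random.Random(seed)  # local RNG instead of np.random + temp_seed
    order = list(range(len(samples)))
    rng.shuffle(order)
    # Take one extra index so that dropping `exclude` still leaves `num` items
    index = order[: num if exclude is None else num + 1]
    if exclude is not None and exclude in index:
        index.remove(exclude)
    else:
        index = index[:num]
    return [samples[i] for i in index]
```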
class SST2Dataset(Dataset):
train_sep = "\n\n"
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset('glue', 'sst2')
train_d = d["train"]
validation_d = d["validation"]
train_samples = [self.build_sample(example) for example in train_d]
valid_samples = [self.build_sample(example) for example in validation_d]
self.samples = {"train": train_samples, "valid": valid_samples}
# for generative tasks, candidates are []
def build_sample(self, example):
label = int(example["label"])
return Sample(id=example["idx"], data=example, correct_candidate=label, candidates=[0, 1])
def get_template(self, template_version=0):
return {0: SST2Template}[template_version]()
class CopaDataset(Dataset):
train_sep = "\n\n"
mixed_set = False
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
train_examples = load_dataset('super_glue', "copa")["train"]
valid_examples = load_dataset('super_glue', "copa")["validation"]
train_samples = [self.build_sample(example) for example in train_examples]
valid_samples = [self.build_sample(example) for example in valid_examples]
self.samples = {"train": train_samples, "valid": valid_samples}
# for generative tasks, candidates are []
def build_sample(self, example):
sample = \
Sample(
id=example["idx"],
data=example,
candidates=[example["choice1"], example["choice2"]],
correct_candidate=example[f"choice{example['label'] + 1}"],
)
return sample
def get_template(self, template_version=0):
return {0: CopaTemplate}[template_version]()
class BoolQDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("boolq")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=["Yes", "No"],
correct_candidate="Yes" if example["answer"] else "No",
)
return sample
def get_template(self, template_version=2):
return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]()
class MultiRCDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "multirc")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=[0, 1],
correct_candidate=example['label']
)
return sample
def get_template(self, template_version=0):
return {0: MultiRCTemplate}[template_version]()
class CBDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "cb")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=[0, 1, 2],
correct_candidate=example['label']
)
return sample
def get_template(self, template_version=0):
return {0: CBTemplate}[template_version]()
class WICDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "wic")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=[0, 1],
correct_candidate=example['label']
)
return sample
def get_template(self, template_version=0):
return {0: WICTemplate}[template_version]()
class WSCDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "wsc.fixed")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=[0, 1],
correct_candidate=example['label']
)
return sample
def get_template(self, template_version=0):
return {0: WSCTemplate}[template_version]()
class ReCoRDDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "record")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=example['entities'],
correct_candidate=example['answers']
)
return sample
def get_template(self, template_version=0):
return {0: ReCoRDTemplateGPT3}[template_version]()
class RTEDataset(Dataset):
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset(subtask, **kwargs)
def load_dataset(self, path, **kwargs):
d = load_dataset("super_glue", "rte")
train_set = d["train"]
valid_set = d["validation"]
train_samples = [self.build_sample(example) for example in train_set]
valid_samples = [self.build_sample(example) for example in valid_set]
self.samples = {"train": train_samples, "valid": valid_samples}
def build_sample(self, example):
sample = \
Sample(
data=example,
candidates=[0, 1],
correct_candidate=example['label']
)
return sample
def get_template(self, template_version=0):
return {0: RTETemplate}[template_version]()
class SQuADDataset(Dataset):
metric_name = "f1"
generation = True
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset()
def load_dataset(self):
dataset = load_dataset("squad")
train_examples = dataset["train"]
valid_examples = dataset["validation"]
train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
self.samples = {"train": train_samples, "valid": valid_samples}
# for generative tasks, candidates is None
def build_sample(self, example, idx):
answers = example['answers']['text']
assert len(answers) > 0
return Sample(
id=idx,
data={
"title": example['title'],
"context": example['context'],
"question": example['question'],
"answers": answers
},
candidates=None,
correct_candidate=answers
)
def get_template(self, template_version=0):
return {0: SQuADv2Template}[template_version]()
class DROPDataset(Dataset):
metric_name = "f1"
generation = True
def __init__(self, subtask=None, **kwargs) -> None:
self.load_dataset()
def load_dataset(self):
dataset = load_dataset("drop")
train_examples = dataset["train"]
valid_examples = dataset["validation"]
train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
self.samples = {"train": train_samples, "valid": valid_samples}
# for generative tasks, candidates is None
def build_sample(self, example, idx):
answers = example['answers_spans']['spans']
assert len(answers) > 0
return Sample(
id=idx,
data={
"context": example['passage'],
"question": example['question'],
"answers": answers
},
candidates=None,
correct_candidate=answers
)
def get_template(self, template_version=0):
return {0: DROPTemplate}[template_version]()
================================================
FILE: example/mezo_runner/templates.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
"""
Copied from https://github.com/princeton-nlp/MeZO/blob/main/large_models/templates.py
"""
class Template:
def encode(self, sample):
"""
Return prompted version of the example (without the answer/candidate)
"""
raise NotImplementedError
def verbalize(self, sample, candidate):
"""
Return the prompted version of the example (with the answer/candidate)
"""
return candidate
def encode_sfc(self, sample):
"""
Same as encode, but for SFC (calibration) -- this usually means the input is not included
"""
return "<mask>"
def verbalize_sfc(self, sample, candidate):
"""
Same as verbalize, but for SFC (calibration) -- this usually means the input is not included
"""
return candidate
class SST2Template(Template):
verbalizer = {0: "terrible", 1: "great"}
def encode(self, sample):
text = sample.data["sentence"].strip()
return f"{text} It was"
def verbalize(self, sample, candidate):
text = sample.data["sentence"].strip()
return f"{text} It was {self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f" It was"
def verbalize_sfc(self, sample, candidate):
return f" It was {self.verbalizer[candidate]}"
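For classification tasks, one prompt is scored per candidate label. `SST2Template` maps label ids to the words `terrible` and `great`. The sketch below mirrors `encode` and `verbalize` on plain strings, a hypothetical simplification of `Sample.data["sentence"]`.

```python
from typing import Dict, List

VERBALIZER: Dict[int, str] = {0: "terrible", 1: "great"}


def encode(sentence: str) -> str:
    # Prompt without the answer word
    return f"{sentence.strip()} It was"


def verbalize(sentence: str, candidate: int) -> str:
    # Same prompt followed by the label word for `candidate`
    return f"{sentence.strip()} It was {VERBALIZER[candidate]}"


def candidate_prompts(sentence: str, candidates: List[int]) -> List[str]:
    return [verbalize(sentence, c) for c in candidates]
```

Every candidate prompt begins with the encoded prefix. `encode_prompt` uses the difference in token count between the two forms as the option length.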
class CopaTemplate(Template):
capitalization: str = "correct"
effect_conj: str = " so "
cause_conj: str = " because "
def get_conjucture(self, sample):
if sample.data["question"] == "effect":
conjunction = self.effect_conj
elif sample.data["question"] == "cause":
conjunction = self.cause_conj
else:
raise NotImplementedError
return conjunction
def get_prompt(self, sample):
premise = sample.data["premise"].rstrip()
if premise.endswith("."): # TODO Add other scripts with different punctuation
premise = premise[:-1]
conjunction = self.get_conjucture(sample)
prompt = premise + conjunction
if self.capitalization == "upper":
prompt = prompt.upper()
elif self.capitalization == "lower":
prompt = prompt.lower()
return prompt
def encode(self, sample):
prompt = self.get_prompt(sample)
return prompt
def capitalize(self, c):
if self.capitalization == "correct":
words = c.split(" ")
if words[0] != "I":
words[0] = words[0].lower()
return " ".join(words)
elif self.capitalization == "bug":
return c
elif self.capitalization == "upper":
return c.upper()
elif self.capitalization == "lower":
return c.lower()
else:
raise NotImplementedError
def verbalize(self, sample, candidate):
prompt = self.get_prompt(sample)
return prompt + self.capitalize(candidate)
def encode_sfc(self, sample):
conjunction = self.get_conjucture(sample)
return conjunction.strip()
def verbalize_sfc(self, sample, candidate):
conjunction = self.get_conjucture(sample)
sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate)
return sfc_prompt
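With the default `capitalization = "correct"`, `CopaTemplate` builds each prompt in three steps:

1. It drops a trailing period from the premise.
2. It appends ` so ` or ` because `, depending on the question type.
3. It lowercases the first word of the choice unless that word is `I`.

The sketch below reproduces that string assembly on plain strings. It raises `ValueError` for an unknown question type, whereas the template raises `NotImplementedError`.

```python
def copa_verbalize(premise: str, question: str, choice: str) -> str:
    premise = premise.rstrip()
    if premise.endswith("."):
        premise = premise[:-1]
    if question == "effect":
        conj = " so "
    elif question == "cause":
        conj = " because "
    else:
        # The template raises NotImplementedError here
        raise ValueError(f"unknown question type: {question}")
    words = choice.split(" ")
    if words[0] != "I":
        words[0] = words[0].lower()
    return premise + conj + " ".join(words)
```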
class BoolQTemplate(Template):
def encode(self, sample):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question}"
def verbalize(self, sample, candidate):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question} {candidate}"
def encode_sfc(self, sample):
return ""
def verbalize_sfc(self, sample, candidate):
return candidate
class BoolQTemplateV2(Template):
def encode(self, sample):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question}\\n\\n"
def verbalize(self, sample, candidate):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question}\\n\\n{candidate}"
def encode_sfc(self, sample):
return ""
def verbalize_sfc(self, sample, candidate):
return candidate
class BoolQTemplateV3(Template):
def encode(self, sample):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question}\n"
def verbalize(self, sample, candidate):
passage = sample.data["passage"]
question = sample.data["question"]
if not question.endswith("?"):
question = question + "?"
question = question[0].upper() + question[1:]
return f"{passage} {question}\n{candidate}"
def encode_sfc(self, sample):
return ""
def verbalize_sfc(self, sample, candidate):
return candidate
class MultiRCTemplate(Template):
# From PromptSource 1
verbalizer = {0: "No", 1: "Yes"}
def encode(self, sample):
paragraph = sample.data["paragraph"]
question = sample.data["question"]
answer = sample.data["answer"]
return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n"
def verbalize(self, sample, candidate):
paragraph = sample.data["paragraph"]
question = sample.data["question"]
answer = sample.data["answer"]
return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n{self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f""
def verbalize_sfc(self, sample, candidate):
return f"{self.verbalizer[candidate]}"
class CBTemplate(Template):
# From PromptSource 1
verbalizer = {0: "Yes", 1: "No", 2: "Maybe"}
def encode(self, sample):
premise = sample.data["premise"]
hypothesis = sample.data["hypothesis"]
return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n"
def verbalize(self, sample, candidate):
premise = sample.data["premise"]
hypothesis = sample.data["hypothesis"]
return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n{self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f""
def verbalize_sfc(self, sample, candidate):
return f"{self.verbalizer[candidate]}"
class WICTemplate(Template):
# From PromptSource 1
verbalizer = {0: "No", 1: "Yes"}
def encode(self, sample):
sent1 = sample.data["sentence1"]
sent2 = sample.data["sentence2"]
word = sample.data["word"]
return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n"
def verbalize(self, sample, candidate):
sent1 = sample.data["sentence1"]
sent2 = sample.data["sentence2"]
word = sample.data["word"]
return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n{self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f""
def verbalize_sfc(self, sample, candidate):
return f"{self.verbalizer[candidate]}"
class WSCTemplate(Template):
# From PromptSource 1
verbalizer = {0: "No", 1: "Yes"}
def encode(self, sample):
text = sample.data['text']
span1 = sample.data['span1_text']
span2 = sample.data['span2_text']
return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n"
def verbalize(self, sample, candidate):
text = sample.data['text']
span1 = sample.data['span1_text']
span2 = sample.data['span2_text']
return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n{self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f""
def verbalize_sfc(self, sample, candidate):
return f"{self.verbalizer[candidate]}"
class ReCoRDTemplate(Template):
# From PromptSource 1 but modified
def encode(self, sample):
passage = sample.data['passage']
query = sample.data['query']
return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer:"
def verbalize(self, sample, candidate):
passage = sample.data['passage']
query = sample.data['query']
return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"
def encode_sfc(self, sample):
return f"Answer:"
def verbalize_sfc(self, sample, candidate):
return f"Answer: {candidate}"
class ReCoRDTemplateGPT3(Template):
# From PromptSource 1 but modified
def encode(self, sample):
passage = sample.data['passage'].replace("@highlight\n", "- ")
return f"{passage}\n-"
def verbalize(self, sample, candidate):
passage = sample.data['passage'].replace("@highlight\n", "- ")
query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
return f"{passage}\n- {query}"
# passage = sample.data['passage']
# query = sample.data['query']
# return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"
def encode_sfc(self, sample):
return f"-"
def verbalize_sfc(self, sample, candidate):
query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
return f"- {query}"
class RTETemplate(Template):
# From PromptSource 1
verbalizer={0: "Yes", 1: "No"}
def encode(self, sample):
premise = sample.data['premise']
hypothesis = sample.data['hypothesis']
return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n"
def verbalize(self, sample, candidate):
premise = sample.data['premise']
hypothesis = sample.data['hypothesis']
return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n{self.verbalizer[candidate]}"
def encode_sfc(self, sample):
return f""
def verbalize_sfc(self, sample, candidate):
return f"{self.verbalizer[candidate]}"
class SQuADv2Template(Template):
def encode(self, sample):
question = sample.data['question'].strip()
title = sample.data['title']
context = sample.data['context']
answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer:"
def verbalize(self, sample, candidate):
question = sample.data['question'].strip()
title = sample.data['title']
context = sample.data['context']
answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer: {answer}\n"
def encode_sfc(self, sample):
raise NotImplementedError
def verbalize_sfc(self, sample, candidate):
raise NotImplementedError
class DROPTemplate(Template):
def encode(self, sample):
question = sample.data['question'].strip()
# title = sample.data['title']
context = sample.data['context']
answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
return f"Passage: {context}\nQuestion: {question}\nAnswer:"
def verbalize(self, sample, candidate):
question = sample.data['question'].strip()
# title = sample.data['title']
context = sample.data['context']
answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one
return f"Passage: {context}\nQuestion: {question}\nAnswer: {answer}\n"
def encode_sfc(self, sample):
raise NotImplementedError
def verbalize_sfc(self, sample, candidate):
raise NotImplementedError
================================================
FILE: example/mezo_runner/utils.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
"""
Modified from https://github.com/princeton-nlp/MeZO/blob/main/large_models/utils.py
"""
import contextlib
import json
import logging
import os
import signal
import time
from collections.abc import Mapping
from dataclasses import asdict, dataclass, is_dataclass
from subprocess import call
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import PaddingStrategy
InputDataClass = NewType("InputDataClass", Any)
logger = logging.getLogger(__name__)
def custom_loss_fn_with_option_len(self, input_ids, logits, labels, option_len=None, num_options=None):
"""
Same loss as 'forward_wrap_with_option_len' (below), but computed from precomputed logits instead of calling the model's forward.
"""
loss = None
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
# Use input_ids (which should always equal labels) because labels sometimes hold the correct candidate IDs
shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
shift_labels[shift_labels == self.config.pad_token_id] = -100
# Apply option len (do not calculate loss on the non-option part)
for _i, _len in enumerate(option_len):
shift_labels[_i, :-_len] = -100
# Calculate the loss
loss_fct = CrossEntropyLoss(ignore_index=-100)
if num_options is not None:
# Train as a classification task
log_probs = F.log_softmax(shift_logits, dim=-1)
mask = shift_labels != -100 # Option part
shift_labels[~mask] = 0 # So that it doesn't mess up with indexing
selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)
selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)
if any([x != num_options[0] for x in num_options]):
# Multi choice tasks with different number of options
loss = 0
start_id = 0
count = 0
while start_id < len(num_options):
end_id = start_id + num_options[start_id]
_logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options)
_labels = labels[start_id:end_id][0].unsqueeze(0) # (1)
loss = loss_fct(_logits, _labels) + loss
count += 1
start_id = end_id
loss = loss / count
else:
num_options = num_options[0]
selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)
labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one
loss = loss_fct(selected_log_probs, labels)
else:
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
return loss
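When `num_options` is given, the loss is computed in classification style. Each option is scored by the mean log-probability of its option tokens. Cross-entropy is then taken over one example's options; the gold label repeats on every row, so only the first copy is used. If examples have different option counts, the per-example losses are averaged. The dependency-free sketch below assumes the per-token log-probabilities of each candidate are already available as plain lists.

```python
import math
from typing import List, Sequence


def option_scores(token_logprobs: Sequence[Sequence[float]]) -> List[float]:
    # Mean log-prob over each option's tokens (the masked mean in the loss fn)
    return [sum(lp) / len(lp) for lp in token_logprobs]


def classification_loss(token_logprobs: Sequence[Sequence[float]], gold: int) -> float:
    scores = option_scores(token_logprobs)
    # Cross-entropy over options: logsumexp(scores) - scores[gold]
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[gold]


def grouped_loss(groups: Sequence[Sequence[Sequence[float]]], golds: Sequence[int]) -> float:
    # Examples with different numbers of options: average the per-example losses
    losses = [classification_loss(g, y) for g, y in zip(groups, golds)]
    return sum(losses) / len(losses)
```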
def forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs):
"""
This is to replace the original forward function of Transformer models to enable:
(1) Partial target sequence: loss will only be calculated on part of the sequence
(2) Classification-style training: a classification loss (CE) will be calculated over several options
Input:
- input_ids, labels: same as the original forward function
- option_len: a list of int indicating the option lengths, and loss will be calculated only on the
last option_len tokens
- num_options: a list of int indicating the number of options for each example (this will be #label
words for classification tasks and #choices for multiple choice tasks), and a classification loss
will be calculated.
"""
outputs = self.original_forward(input_ids=input_ids, **kwargs)
if labels is None:
return outputs
logits = outputs.logits
loss = None
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
# Use input_ids (which should always equal labels) because labels sometimes hold the correct candidate IDs
shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
shift_labels[shift_labels == self.config.pad_token_id] = -100
# Apply option len (do not calculate loss on the non-option part)
for _i, _len in enumerate(option_len):
shift_labels[_i, :-_len] = -100
# Calculate the loss
loss_fct = CrossEntropyLoss(ignore_index=-100)
if num_options is not None:
# Train as a classification task
log_probs = F.log_softmax(shift_logits, dim=-1)
mask = shift_labels != -100 # Option part
shift_labels[~mask] = 0 # So that it doesn't mess up with indexing
selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)
selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)
if any([x != num_options[0] for x in num_options]):
# Multi choice tasks with different number of options
loss = 0
start_id = 0
count = 0
while start_id < len(num_options):
end_id = start_id + num_options[start_id]
_logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options)
_labels = labels[start_id:end_id][0].unsqueeze(0) # (1)
loss = loss_fct(_logits, _labels) + loss
count += 1
start_id = end_id
loss = loss / count
else:
num_options = num_options[0]
selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)
labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one
loss = loss_fct(selected_log_probs, labels)
else:
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
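Option-length masking operates on the shifted labels. Position t predicts token t+1, and pad positions become `-100`. Every position before the last `option_len` is also set to `-100`. As a result, `CrossEntropyLoss(ignore_index=-100)` counts only the answer tokens. The single-sequence sketch below works on Python lists and assumes `option_len > 0`. With `0`, the tensor slice `[:-0]` used above is empty and masks nothing.

```python
from typing import List

IGNORE_INDEX = -100


def shifted_option_labels(input_ids: List[int], option_len: int, pad_token_id: int) -> List[int]:
    assert option_len > 0, "sketch assumes a non-empty option"
    # Token t predicts token t+1, so labels are input_ids shifted left by one
    labels = [IGNORE_INDEX if t == pad_token_id else t for t in input_ids[1:]]
    # Ignore every position except the last `option_len` ones
    for i in range(len(labels) - option_len):
        labels[i] = IGNORE_INDEX
    return labels
```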
def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, max_new_tokens=None):
"""
Encode prompts for eval_sample
Input:
- task, template: task and template class
- train_samples, eval_sample: demonstrations and the actual sample
- tokenizer, max_length: tokenizer and max length
- sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315)
- icl_sfc: generate prompts for ICL version calibration
- generation: whether it is a generation task
- generation_with_gold: whether to include the generation-task gold answers (for training)
- max_new_tokens: max number of new tokens to generate so that we can save enough space
(only for generation tasks)
Output:
- encodings: a list of N lists of tokens. N is the number of options for classification/multiple-choice.
- option_lens: a list of N integers indicating the number of option tokens.
"""
# Demonstrations for ICL
train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]
train_prompts = task.train_sep.join(train_prompts).strip()
# sfc or icl_sfc indicates that this example is used for calibration
if sfc or icl_sfc:
encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc
else:
encode_fn = template.encode; verbalize_fn = template.verbalize
unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')
if not generation:
# We generate one prompt for each candidate (different classes in classification)
# or different choices in multiple-choice tasks
verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]
unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
if sfc:
# Without demonstrations
final_prompts = verbalized_eval_prompts
else:
# With demonstrations
final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
else:
assert not sfc and not icl_sfc, "Generation tasks do not support SFC"
if generation_with_gold:
verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]
unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
else:
option_lens = [0]
final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]
# Tokenize
encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]
# Truncate (left truncate as demonstrations are less important)
if generation and max_new_tokens is not None:
max_length = max_length - max_new_tokens
if any([len(encoding) > max_length for encoding in encodings]):
logger.warning("Exceeded max length")
if tokenizer.add_bos_token:
encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings]
else:
encodings = [encoding[-max_length:] for encoding in encodings]
return encodings, option_lens
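`encode_prompt` truncates from the left, because the demonstrations at the start of the prompt matter least. If the tokenizer prepends a BOS token, that first token is kept and the rest is cut to `max_length - 1` tokens. For generation tasks, room for `max_new_tokens` is reserved first. The sketch below applies the same rule to plain token-id lists. The `add_bos_token` flag stands in for the tokenizer attribute.

```python
from typing import List


def left_truncate(encodings: List[List[int]], max_length: int, add_bos_token: bool,
                  max_new_tokens: int = 0) -> List[List[int]]:
    # Leave room for the tokens the model will generate (generation tasks only)
    max_length = max_length - max_new_tokens
    if add_bos_token:
        # Keep the BOS token, then the last (max_length - 1) tokens
        return [enc[0:1] + enc[1:][-(max_length - 1):] for enc in encodings]
    return [enc[-max_length:] for enc in encodings]
```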
@dataclass
class ICLCollator:
"""
Collator for ICL
"""
tokenizer: PreTrainedTokenizerBase
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
if not isinstance(features[0], Mapping):
features = [vars(f) for f in features]
first = features[0]
batch = {}
pad_id = self.tokenizer.pad_token_id
pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0, "labels": pad_id}
for key in first:
pp = pad_ids[key]
lens = [len(f[key]) for f in features]
max_len = max(lens)
feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in enumerate(features)])
padded_feature = torch.from_numpy(feature).long()
batch[key] = padded_feature
return batch
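`ICLCollator` right-pads every field to the longest sequence in the batch. Token fields use the tokenizer's pad id, and attention masks use `0`. The list-based sketch below follows the same rule. A hypothetical `pad_id` argument replaces `tokenizer.pad_token_id`, and the NumPy/tensor conversion is left out.

```python
from typing import Dict, List


def pad_batch(features: List[Dict[str, List[int]]], pad_id: int) -> Dict[str, List[List[int]]]:
    pad_values = {"input_ids": pad_id, "attention_mask": 0,
                  "sfc_input_ids": pad_id, "sfc_attention_mask": 0, "labels": pad_id}
    batch: Dict[str, List[List[int]]] = {}
    for key in features[0]:
        max_len = max(len(f[key]) for f in features)
        # Right-pad each row to max_len with the key-specific pad value
        batch[key] = [f[key] + [pad_values[key]] * (max_len - len(f[key])) for f in features]
    return batch
```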
@dataclass
class DataCollatorWithPaddingAndNesting:
"""
Collator for training
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
features = [ff for f in features for ff in f]
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
if "label" in batch:
batch["labels"] = batch["label"]
del batch["label"]
if "label_ids" in batch:
batch["labels"] = batch["label_ids"]
del batch["label_ids"]
return batch
@dataclass
class NondiffCollator(DataCollatorMixin):
"""
Collator for non-differentiable objectives
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
def torch_call(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != "gold"} for feature in features]
batch = self.tokenizer.pad(
no_labels_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
if labels is None:
return batch
sequence_length = batch["input_ids"].shape[1]
padding_side = self.tokenizer.padding_side
def to_list(tensor_or_iterable):
if isinstance(tensor_or_iterable, torch.Tensor):
return tensor_or_iterable.tolist()
return list(tensor_or_iterable)
if padding_side == "right":
batch[label_name] = [
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]
else:
batch[label_name] = [
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
]
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
if "gold" in features[0]:
batch["gold"] = [feature["gold"] for feature in features]
return batch
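The label-padding branch in `NondiffCollator.torch_call` can be isolated as a pure function, which makes its left/right behavior easy to check without a tokenizer. This is a minimal sketch; `pad_labels` is a hypothetical helper, not part of the repository. The pad id `-100` matches the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so padded label positions do not contribute to the loss.

```python
from typing import List, Sequence


def pad_labels(labels: Sequence[Sequence[int]], sequence_length: int,
               pad_id: int = -100, padding_side: str = "right") -> List[List[int]]:
    """Pad every label list to sequence_length, mirroring the collator's two branches."""
    padded = []
    for label in labels:
        pad = [pad_id] * (sequence_length - len(label))
        if padding_side == "right":
            # Right padding: pad ids follow the real labels.
            padded.append(list(label) + pad)
        else:
            # Any other side is treated as left padding, as in the collator's else-branch.
            padded.append(pad + list(label))
    return padded


right = pad_labels([[5, 6], [7]], 3)
left = pad_labels([[5, 6], [7]], 3, padding_side="left")
print(right)  # [[5, 6, -100], [7, -100, -100]]
print(left)   # [[-100, 5, 6], [-100, -100, 7]]
```

Following the tokenizer's `padding_side` keeps each label aligned with its token in `input_ids`.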
class SIGUSR1Callback(transformers.TrainerCallback):
"""
This callback is used to save the model when a SIGUSR1 signal is received
(SLURM stop signal or a keyboard interruption signal).
"""
def __init__(self) -> None:
super().__init__()
self.signal_received = False
signal.signal(signal.SIGUSR1, self.handle_signal)
signal.signal(signal.SIGINT, self.handle_signal)
logger.warn("Handler registered")
def handle_signal(self, signum, frame):
self.signal_received = True
logger.warn("Signal received")
def on_step_end(self, args, state, control, **kwargs):
if self.signal_received:
control.should_save = True
control.should_training_stop = True
def on_train_end(self, args, state, control, **kwargs):
if self.signal_received:
exit(0)
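`SIGUSR1Callback` only sets a flag inside the signal handler and defers the actual save/stop to `on_step_end`, presumably because doing real work inside a handler is fragile. The stdlib-only sketch below shows the same flag pattern outside of `transformers`. `StopFlag` is a hypothetical name; the sketch assumes a POSIX system (Windows has no `SIGUSR1`) and must run in the main thread, which is where Python requires signal handlers to be installed.

```python
import signal


class StopFlag:
    """Record that SIGUSR1 or SIGINT arrived instead of terminating immediately."""

    def __init__(self) -> None:
        self.signal_received = False
        self._previous = {}
        for sig in (signal.SIGUSR1, signal.SIGINT):
            # signal.signal returns the handler it replaces, so it can be restored later.
            self._previous[sig] = signal.signal(sig, self.handle_signal)

    def handle_signal(self, signum, frame) -> None:
        # Keep the handler trivial: just set a flag for the training loop to poll.
        self.signal_received = True

    def restore(self) -> None:
        for sig, handler in self._previous.items():
            if handler is not None:  # None means the old handler was not set from Python
                signal.signal(sig, handler)


flag = StopFlag()
signal.raise_signal(signal.SIGUSR1)  # send SIGUSR1 to this process (Python 3.8+)
print(flag.signal_received)  # True
flag.restore()
```

A training loop would check `flag.signal_received` once per step, which is what `on_step_end` does with `control.should_save` and `control.should_training_stop`.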
@dataclass
class Prediction:
correct_candidate: Union[int, str]
predicted_candidate: Union[int, str]
@contextlib.contextmanager
def count_time(name):
logger.info("%s..." % name)
start_time = time.time()
try:
yield
finally:
logger.info("Done with %.2fs" % (time.time() - start_time))
@contextlib.contextmanager
def temp_seed(seed):
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
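`temp_seed` saves the global NumPy RNG state, reseeds, and restores the saved state on exit, so seeded work inside the block does not perturb randomness elsewhere. The sketch below demonstrates that property with Python's `random` module instead of NumPy (a swapped-in stdlib analogue so it runs without extra packages); `temp_random_seed` is a hypothetical name.

```python
import contextlib
import random


@contextlib.contextmanager
def temp_random_seed(seed):
    """Seed Python's `random` module temporarily, then restore the previous state."""
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restoring the saved state lets code after the block continue
        # the original random stream as if the block never ran.
        random.setstate(state)


random.seed(0)
expected = [random.random() for _ in range(3)]

random.seed(0)
with temp_random_seed(42):
    inside_a = random.random()
after = [random.random() for _ in range(3)]

with temp_random_seed(42):
    inside_b = random.random()

print(after == expected, inside_a == inside_b)  # True True
```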
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o)
return super().default(o)
def write_predictions_to_file(final_preds, output):
with open(output, "w") as f:
for pred in final_preds:
f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n")
def write_metrics_to_file(metrics, output):
    # Use a context manager so the file is flushed and closed deterministically.
    with open(output, "w") as f:
        json.dump(metrics, f, cls=EnhancedJSONEncoder, indent=4)
================================================
FILE: requirements.txt
================================================
brotli==1.0.9
certifi==2024.7.4
charset-normalizer==3.3.2
filelock==3.13.1
idna==3.7
Jinja2==3.1.4
MarkupSafe==2.1.3
numpy==1.26.4
Pillow==10.4.0
PySocks==1.7.1
PyYAML==6.0.1
requests==2.32.3
rich==14.0.0
setuptools==72.1.0
urllib3==2.2.2
wheel==0.43.0
accelerate==1.6.0
datasets==3.5.1
aiohttp==3.10.3
aiosignal==1.3.1
attrs==24.2.0
dill==0.3.8
frozenlist==1.4.1
fsspec==2024.5.0
huggingface-hub==0.24.5
joblib==1.4.2
multidict==6.0.5
multiprocess==0.70.16
opt-einsum==3.3.0
packaging==24.1
pandas==2.2.2
psutil==6.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
python-dateutil==2.9.0.post0
pytz==2024.1
regex==2024.7.24
scikit-learn==1.5.1
scipy==1.14.0
six==1.16.0
threadpoolctl==3.5.0
tokenizers==0.21.1
tqdm==4.66.5
transformers==4.51.3
tzdata==2024.1
xxhash==3.4.1
yarl==1.9.4
nvidia-ml-py==12.570.86
trl==0.17.0
safetensors==0.5.2
================================================
FILE: script/add-copyright.py
================================================
import os
import datetime
import logging
current_year = datetime.datetime.now().year
owner = "liangyuwang"
logging.basicConfig(filename='license_addition_errors.log', level=logging.ERROR)
def add_license_header(file_path, comment_style):
    try:
        with open(file_path, 'r+', encoding='utf-8') as file:
            content = file.read()
            license_snippet = "Licensed under the Apache License, Version 2.0"
            if license_snippet in content:
                return  # header already present; keeps the script idempotent
            if comment_style == "block":
                header = f"/* Copyright (c) {current_year} {owner}\n * {license_snippet}\n */\n\n"
            elif comment_style == "line":
                header = f"# Copyright (c) {current_year} {owner}\n# {license_snippet}\n\n"
            else:
                logging.error(f"Unknown comment style '{comment_style}' for {file_path}")
                return
            # Rewrite from the start; the new text is longer, so no stale bytes remain.
            file.seek(0, 0)
            file.write(header + content)
    except FileNotFoundError:
        logging.error(f"File not found: {file_path}")
    except UnicodeDecodeError:
        logging.error(f"Skipping non-UTF-8 file: {file_path}")
file_map = {
'.cpp': 'block',
'.h': 'block',
'.cu': 'block',
'.py': 'line',
'.cmake': 'line'
}
for root, dirs, files in os.walk("."):
for file in files:
ext = os.path.splitext(file)[1]
if ext in file_map:
add_license_header(os.path.join(root, file), file_map[ext])
elif 'CMakeLists.txt' in file:
add_license_header(os.path.join(root, file), 'line')
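The header logic in `add-copyright.py` is interleaved with file I/O. Pulling the string transformation into a pure function makes the idempotency check (skip files that already contain the license line) easy to test. A sketch; `with_license_header` is a hypothetical helper, not part of the script.

```python
LICENSE_SNIPPET = "Licensed under the Apache License, Version 2.0"


def with_license_header(content: str, comment_style: str, year: int, owner: str) -> str:
    """Return content with a license header prepended, unless it already has one."""
    if LICENSE_SNIPPET in content:
        return content  # idempotent: running the script twice adds nothing
    if comment_style == "block":
        header = f"/* Copyright (c) {year} {owner}\n * {LICENSE_SNIPPET}\n */\n\n"
    elif comment_style == "line":
        header = f"# Copyright (c) {year} {owner}\n# {LICENSE_SNIPPET}\n\n"
    else:
        raise ValueError(f"unknown comment style: {comment_style!r}")
    return header + content


once = with_license_header("print('hi')\n", "line", 2025, "liangyuwang")
twice = with_license_header(once, "line", 2025, "liangyuwang")
print(once == twice)         # True
print(once.splitlines()[0])  # # Copyright (c) 2025 liangyuwang
```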
================================================
FILE: script/clear-pycache.sh
================================================
find . | grep -E "(/__pycache__$|\.pyc$|\.pyo$)" | xargs rm -rf
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
with open('requirements.txt') as f:
    requirements = f.read().splitlines()
with open('README.md', encoding='utf-8') as f:
    long_description = f.read()
setup(
name='zo2',
version='0.1.1',
author='liangyuwang',
author_email='liangyu.wang@kaust.edu.sa',
description='ZO2 (Zeroth-Order Offloading), a framework for full-parameter fine-tuning of 175B LLMs with 18GB of GPU memory',
long_description=long_description,
long_description_content_type='text/markdown',
packages=find_packages(),
install_requires=requirements, # List of dependencies, read from requirements.txt
classifiers=[
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
],
python_requires='>=3.11',
include_package_data=True,
zip_safe=False
)
================================================
FILE: test/README.md
================================================
# Test
- Important Notice: For fine-tuning the **OPT-175B** model, ensure that your system is equipped with at least `18GB of GPU memory` and `600GB of CPU memory`.
## Example: MeZO-SGD on OPT Models
```shell
# compare memory
bash test/mezo_sgd/hf_opt/test_memory_train.sh
```
```shell
# compare throughput
bash test/mezo_sgd/hf_opt/test_speed_train.sh
```
```shell
# compare accuracy
bash test/mezo_sgd/hf_opt/test_acc_train.sh
```
## Supported Tests
In progress...
================================================
FILE: test/mezo_sgd/hf_gpt/trainer.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
================================================
FILE: test/mezo_sgd/hf_llama/trainer.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
================================================
FILE: test/mezo_sgd/hf_opt/record_zo2_memory.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
CMD2="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD2 2>&1 | tee $OUT2
echo "Recording Peak GPU and CPU Memory usage..."
max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)
max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)
if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
echo "Could not find memory usage data in the output."
else
echo -e "Model: $model_name, Task: $task_id"
echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
echo -e "ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}"
fi
rm $OUT2
done
done
================================================
FILE: test/mezo_sgd/hf_opt/record_zo2_speed.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
CMD2="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD2 2>&1 | tee $OUT2
echo "Recording throughput..."
# Count the total number of lines and determine the number of iteration lines
total_lines2=$(wc -l < $OUT2)
iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)
# Calculate the starting line for the last 50% of iterations
start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))
# Calculate average tokens per second for the last 50% of the iterations
avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
echo -e "Model: $model_name, Task: $task_id"
echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
rm $OUT2
done
done
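`record_zo2_speed.sh` averages throughput over the last half of the iterations, presumably so that warm-up steps (CUDA initialization, allocator growth) do not drag the mean down. Its line arithmetic assumes the iteration lines sit at the end of the log and skips the first `iter_lines / 2` of them. The same rule, applied to a list of already-extracted tok/s values, is sketched below; `tail_mean` and the sample numbers are illustrative only.

```python
from statistics import mean
from typing import Sequence


def tail_mean(values: Sequence[float]) -> float:
    """Average the values after skipping the first len(values) // 2 as warm-up."""
    tail = values[len(values) // 2:]
    if not tail:
        # The awk one-liner would divide by zero here; fail loudly instead.
        raise ValueError("no throughput values to average")
    return mean(tail)


per_iteration_tok_s = [100.0, 400.0, 500.0, 510.0, 490.0, 500.0]
print(tail_mean(per_iteration_tok_s))  # 500.0
```

For an odd count such as 5, both the script and this sketch skip 2 values and average the remaining 3.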
================================================
FILE: test/mezo_sgd/hf_opt/test_acc.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import sys
sys.path.append("../zo2")
import torch
from tqdm import tqdm
from zo2.config.mezo_sgd import MeZOSGDConfig
from zo2.model.huggingface.opt.mezo_sgd import zo, zo2
from zo2.utils.utils import seed_everything
from utils import (
OPTConfigs,
prepare_data_for_causalLM,
prepare_data_for_sequence_classification,
prepare_data_for_question_answering,
model_size,
get_args
)
def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, labels=labels)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, labels=labels)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, labels=labels)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, labels=labels)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, labels=labels)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, labels=labels)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, labels=labels)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, labels=labels)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def train_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to("cuda")
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to("cuda")
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model.zo_train()
loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
res = "Iteration {}, loss: {}, projected grad: {}"
tqdm.write(res.format(i, loss, model.opt.projected_grad))
def eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to("cuda")
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to("cuda")
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
model.zo_eval()
loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)["loss"]
res = "Iteration {}, loss: {}"
tqdm.write(res.format(i, loss))
def test_mezo_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_causalLM_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_causalLM_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_sequence_classification_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_sequence_classification_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_sequence_classification_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_sequence_classification_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_question_answering_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_question_answering_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
zo_cfg.overlap = args.overlap=="all"
eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
if __name__=="__main__":
args = get_args()
original_dtype = torch.get_default_dtype()
if args.zo_method == "zo":
if args.task == "causalLM":
if args.eval:
test_mezo_sgd_causalLM_eval()
else:
test_mezo_sgd_causalLM_training()
elif args.task == "sequence_classification":
if args.eval:
test_mezo_sgd_sequence_classification_eval()
else:
test_mezo_sgd_sequence_classification_training()
elif args.task == "question_answering":
if args.eval:
test_mezo_sgd_question_answering_eval()
else:
test_mezo_sgd_question_answering_training()
else:
raise NotImplementedError(f"Task {args.task} is unsupported.")
elif args.zo_method == "zo2":
if args.task == "causalLM":
if args.eval:
test_mezo2_sgd_causalLM_eval()
else:
test_mezo2_sgd_causalLM_training()
elif args.task == "sequence_classification":
if args.eval:
test_mezo2_sgd_sequence_classification_eval()
else:
test_mezo2_sgd_sequence_classification_training()
elif args.task == "question_answering":
if args.eval:
test_mezo2_sgd_question_answering_eval()
else:
test_mezo2_sgd_question_answering_training()
else:
raise NotImplementedError(f"Task {args.task} is unsupported.")
else:
raise NotImplementedError(f"ZO method {args.zo_method} is unsupported.")
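The `__main__` block in `test_acc.py` selects one of twelve test functions through nested `if`/`elif` branches on `zo_method`, `task`, and `eval`. A table keyed by that triple is one possible way to flatten it; the sketch below uses placeholder lambdas where the real script would map keys such as `("zo", "causalLM", True)` to `test_mezo_sgd_causalLM_eval`. `dispatch` is a hypothetical helper, not code from the repository.

```python
from typing import Callable, Dict, Tuple

Handlers = Dict[Tuple[str, str, bool], Callable[[], None]]


def dispatch(handlers: Handlers, zo_method: str, task: str, eval_mode: bool) -> None:
    """Run the test registered for (zo_method, task, eval_mode)."""
    try:
        handler = handlers[(zo_method, task, eval_mode)]
    except KeyError:
        # Same failure type the original if/elif chain raises for unknown inputs.
        raise NotImplementedError(
            f"Unsupported combination: zo_method={zo_method!r}, task={task!r}") from None
    handler()


calls = []
handlers: Handlers = {
    ("zo", "causalLM", False): lambda: calls.append("zo causalLM train"),
    ("zo2", "causalLM", True): lambda: calls.append("zo2 causalLM eval"),
}
dispatch(handlers, "zo2", "causalLM", True)
print(calls)  # ['zo2 causalLM eval']

try:
    dispatch(handlers, "zo", "summarization", False)
except NotImplementedError as err:
    error_message = str(err)
```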
================================================
FILE: test/mezo_sgd/hf_opt/test_acc_eval.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM" "sequence_classification" "question_answering")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
if [ "$task_id" == "causalLM" ]; then
lr=1e-4
else
lr=1e-7
fi
CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr --eval --max_steps 30"
CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Comparing outputs..."
echo -e "Model: $model_name, Task: $task_id"
paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
split($4, loss1, ",");
split($8, loss2, ",");
diff_loss = loss1[1] - loss2[1];
if (loss1[1] == loss2[1])
printf "Iteration %s: %s✓ loss match.%s\n", $2, green, nc;
else
printf "Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) Loss diff: %.6f%s\n", $2, red, loss1[1], loss2[1], diff_loss, nc;
}'
rm $OUT1 $OUT2
done
done
================================================
FILE: test/mezo_sgd/hf_opt/test_acc_train.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM" "sequence_classification" "question_answering")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
if [ "$task_id" == "causalLM" ]; then
lr=1e-4
else
lr=1e-7
fi
CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr"
CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Comparing outputs..."
echo -e "Model: $model_name, Task: $task_id"
paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
split($4, loss1, ",");
split($7, proj1, ",");
split($11, loss2, ",");
split($14, proj2, ",");
diff_loss = loss1[1] - loss2[1];
diff_proj = proj1[1] - proj2[1];
if (loss1[1] == loss2[1] && proj1[1] == proj2[1])
printf "Iteration %d: %s✓ loss and projected grad match.%s\n", $2, green, nc;
else
printf "Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\n \tLoss diff: %.6f, Proj grad diff: %.6f%s\n", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;
}'
rm $OUT1 $OUT2
done
done
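`test_acc_train.sh` compares the ZO and ZO2 logs by awk field position, which depends on the exact spacing of the `Iteration {}, loss: {}, projected grad: {}` lines. A regex-based comparison is a bit more robust to extra tokens and also reports iterations missing from one run. This sketch assumes loss and projected grad are printed as plain floats; `parse_log` and `mismatched_iterations` are hypothetical helpers. With the default `abs_tol=0.0` it requires exact equality, like the shell script.

```python
import math
import re
from typing import Dict, List, Tuple

# Matches lines such as "Iteration 3, loss: 10.5, projected grad: 0.25".
ITER_RE = re.compile(r"Iteration (\d+), loss: ([^,\s]+), projected grad: (\S+)")


def parse_log(text: str) -> Dict[int, Tuple[float, float]]:
    """Map iteration index -> (loss, projected grad) for every matching line."""
    return {int(m.group(1)): (float(m.group(2)), float(m.group(3)))
            for m in ITER_RE.finditer(text)}


def mismatched_iterations(zo_log: str, zo2_log: str, abs_tol: float = 0.0) -> List[int]:
    """Return iterations whose loss or projected grad differ by more than abs_tol."""
    zo, zo2 = parse_log(zo_log), parse_log(zo2_log)
    bad = []
    for i in sorted(zo.keys() | zo2.keys()):
        if i not in zo or i not in zo2:
            bad.append(i)  # one run never reached this iteration
        elif not all(math.isclose(a, b, rel_tol=0.0, abs_tol=abs_tol)
                     for a, b in zip(zo[i], zo2[i])):
            bad.append(i)
    return bad


zo_log = "Iteration 0, loss: 10.5, projected grad: 0.25\nIteration 1, loss: 10.4, projected grad: -0.125\n"
zo2_log = "Iteration 0, loss: 10.5, projected grad: 0.25\nIteration 1, loss: 10.4, projected grad: -0.12\n"
print(mismatched_iterations(zo_log, zo2_log))  # [1]
```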
================================================
FILE: test/mezo_sgd/hf_opt/test_memory.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import sys
sys.path.append("../zo2")
import torch
from tqdm import tqdm
from zo2.config.mezo_sgd import MeZOSGDConfig
from zo2.model.huggingface.opt.mezo_sgd import zo, zo2
from zo2.utils.utils import seed_everything
from utils import (
OPTConfigs,
prepare_data_for_causalLM,
prepare_data_for_sequence_classification,
prepare_data_for_question_answering,
model_size,
get_args,
check_peak_gpu_memory_usage,
reset_peak_cpu_memory_usage,
check_and_update_peak_cpu_memory_usage
)
def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, torch.device(device).index or 0, True)
check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, torch.device(device).index or 0, True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, torch.device(device).index or 0, True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, torch.device(device).index or 0, True)
check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, labels=labels)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo_sgd_question_answering(model_config, zo_config, device='cuda:0'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_train()
model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda:0'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
torch.cuda.reset_peak_memory_stats()
reset_peak_cpu_memory_usage()
for i in tqdm(range(args.max_steps)):
model.zo_eval()
model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
check_peak_gpu_memory_usage(i, int(device[-1]), True)
check_and_update_peak_cpu_memory_usage(i, True)
def test_mezo_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_causalLM_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_causalLM_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_sequence_classification_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_sequence_classification_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_sequence_classification_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_sequence_classification_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_question_answering_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_question_answering_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)
if __name__=="__main__":
args = get_args()
original_dtype = torch.get_default_dtype()
if args.zo_method == "zo":
if args.task == "causalLM":
if args.eval:
test_mezo_sgd_causalLM_eval()
else:
test_mezo_sgd_causalLM_training()
elif args.task == "sequence_classification":
if args.eval:
test_mezo_sgd_sequence_classification_eval()
else:
test_mezo_sgd_sequence_classification_training()
elif args.task == "question_answering":
if args.eval:
test_mezo_sgd_question_answering_eval()
else:
test_mezo_sgd_question_answering_training()
else:
raise NotImplementedError(f"Task {args.task} is unsupported.")
elif args.zo_method == "zo2":
if args.task == "causalLM":
if args.eval:
test_mezo2_sgd_causalLM_eval()
else:
test_mezo2_sgd_causalLM_training()
elif args.task == "sequence_classification":
if args.eval:
test_mezo2_sgd_sequence_classification_eval()
else:
test_mezo2_sgd_sequence_classification_training()
elif args.task == "question_answering":
if args.eval:
test_mezo2_sgd_question_answering_eval()
else:
test_mezo2_sgd_question_answering_training()
else:
raise NotImplementedError(f"Task {args.task} is unsupported.")
else:
raise NotImplementedError
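The `__main__` block above selects one of twelve `test_*` functions through nested `if`/`elif` branches. A table-driven dispatch is one possible way to express the same mapping more compactly. This is a minimal sketch, not part of the repository: `build_dispatch`, `run`, and `METHOD_PREFIXES` are hypothetical helpers, and the naming scheme they rely on (`test_{mezo|mezo2}_sgd_{task}_{eval|training}`) is simply read off the functions defined in this file.

```python
from typing import Callable, Dict, Mapping, Tuple

# (zo_method, task, eval flag) -> zero-argument test function
DispatchKey = Tuple[str, str, bool]

TASKS = ("causalLM", "sequence_classification", "question_answering")
# Hypothetical: maps the --zo_method value to the function-name prefix used in this file.
METHOD_PREFIXES = {"zo": "mezo", "zo2": "mezo2"}


def build_dispatch(namespace: Mapping[str, Callable[[], None]]) -> Dict[DispatchKey, Callable[[], None]]:
    """Collect test_* functions from `namespace` into a lookup table (hypothetical helper)."""
    table: Dict[DispatchKey, Callable[[], None]] = {}
    for method, prefix in METHOD_PREFIXES.items():
        for task in TASKS:
            for is_eval, suffix in ((True, "eval"), (False, "training")):
                name = f"test_{prefix}_sgd_{task}_{suffix}"
                table[(method, task, is_eval)] = namespace[name]
    return table


def run(zo_method: str, task: str, is_eval: bool,
        table: Mapping[DispatchKey, Callable[[], None]]) -> None:
    """Run the matching test, or raise NotImplementedError for an unknown combination."""
    try:
        test_fn = table[(zo_method, task, is_eval)]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported combination: zo_method={zo_method}, task={task}") from None
    test_fn()
```

In the script this would be called as `run(args.zo_method, args.task, bool(args.eval), build_dispatch(globals()))`. One behavioral difference: the original raises a bare `NotImplementedError` for an unknown `--zo_method`, while this sketch reports both fields in a single message.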
================================================
FILE: test/mezo_sgd/hf_opt/test_memory_eval.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
CMD1="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval"
CMD2="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Analyzing Peak GPU Memory usage..."
max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)
max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)
if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
echo "Could not find memory usage data in the output."
else
ratio=$(echo "scale=2; $max_mem2 * 100 / $max_mem1" | bc)
echo -e "Model: $model_name, Task: $task_id"
echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}"
echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}"
fi
rm $OUT1 $OUT2
done
done
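The analysis step above reduces each log to its largest `Peak GPU Memory` value and then reports ZO2's peak as a percentage of ZO's. Below is a rough Python equivalent of the `grep | awk '{print $7}' | sort -nr | head -1` pipeline, offered only as a sketch: like `awk '{print $7}'`, it takes the seventh whitespace-separated field, so it assumes the number sits in that position. The sample log lines are made up for illustration; the exact format printed by `check_peak_gpu_memory_usage` is not shown in this file, and `peak_memory_mb` / `memory_ratio_percent` are hypothetical names.

```python
from typing import Iterable, Optional


def peak_memory_mb(lines: Iterable[str], marker: str = "Peak GPU Memory") -> Optional[float]:
    """Return the largest value in field 7 of lines containing `marker`, or None if there is none."""
    peaks = []
    for line in lines:
        if marker not in line:
            continue
        fields = line.split()
        if len(fields) < 7:
            continue  # awk would print an empty field here; skip it instead
        try:
            peaks.append(float(fields[6]))
        except ValueError:
            continue  # field 7 is not numeric on this line
    return max(peaks) if peaks else None


def memory_ratio_percent(zo_peak_mb: float, zo2_peak_mb: float) -> float:
    """ZO2 peak memory as a percentage of ZO peak memory, rounded to two decimals."""
    return round(zo2_peak_mb * 100.0 / zo_peak_mb, 2)


if __name__ == "__main__":
    # Made-up log lines for illustration only (assumed format).
    zo_log = ["step 0, Peak GPU Memory usage: 1200.0 MB",
              "step 1, Peak GPU Memory usage: 1000.0 MB"]
    zo2_log = ["step 0, Peak GPU Memory usage: 300.0 MB",
               "step 1, Peak GPU Memory usage: 600.0 MB"]
    zo_peak, zo2_peak = peak_memory_mb(zo_log), peak_memory_mb(zo2_log)
    if zo_peak is None or zo2_peak is None:
        print("Could not find memory usage data in the output.")
    else:
        print(f"Memory usage ratio of ZO2 to ZO: {memory_ratio_percent(zo_peak, zo2_peak)}%")
```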
================================================
FILE: test/mezo_sgd/hf_opt/test_memory_train.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
CMD1="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30"
CMD2="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Analyzing Peak GPU Memory usage..."
max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)
max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)
if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
echo "Could not find memory usage data in the output."
else
ratio=$(echo "scale=2; $max_mem2 * 100 / $max_mem1" | bc)
echo -e "Model: $model_name, Task: $task_id"
echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}"
echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}"
fi
rm $OUT1 $OUT2
done
done
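Both memory scripts compute the ratio with `bc` at `scale=2`. Because `bc` truncates the result of a division to `scale` decimal places, evaluating `mem2 / mem1 * 100` loses precision before the multiplication, whereas `mem2 * 100 / mem1` only truncates the final result. The worked example below emulates that truncation with Python's `decimal` module; the two helper names are hypothetical and only model `bc`'s behavior for these two expressions with positive operands.

```python
from decimal import ROUND_DOWN, Decimal

TWO_PLACES = Decimal("0.01")  # models bc's scale=2


def bc_divide_then_multiply(mem2: int, mem1: int) -> Decimal:
    """Emulate bc: scale=2; mem2 / mem1 * 100 (quotient truncated before multiplying)."""
    quotient = (Decimal(mem2) / Decimal(mem1)).quantize(TWO_PLACES, rounding=ROUND_DOWN)
    return quotient * 100


def bc_multiply_then_divide(mem2: int, mem1: int) -> Decimal:
    """Emulate bc: scale=2; mem2 * 100 / mem1 (only the final quotient is truncated)."""
    return (Decimal(mem2) * 100 / Decimal(mem1)).quantize(TWO_PLACES, rounding=ROUND_DOWN)


if __name__ == "__main__":
    print(bc_divide_then_multiply(1234, 5678))  # 21.00
    print(bc_multiply_then_divide(1234, 5678))  # 21.73
```

With a ZO peak of 5678 MB and a ZO2 peak of 1234 MB, the divide-first order reports 21.00% instead of 21.73%.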
================================================
FILE: test/mezo_sgd/hf_opt/test_scheduler_acc_eval.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM" "sequence_classification" "question_answering")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
if [ "$task_id" == "causalLM" ]; then
lr=1e-4
else
lr=1e-7
fi
CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --overlap no --max_steps 30"
CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Comparing outputs..."
echo -e "Model: $model_name, Task: $task_id"
paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
split($4, loss1, ",");
split($8, loss2, ",");
diff_loss = loss1[1] - loss2[1];
if (loss1[1] == loss2[1])
printf "Iteration %s: %s✓ loss match.%s\n", $2, green, nc;
else
printf "Iteration %s: %s✗ Mismatch! Non-Overlap (loss): (%s), Overlap (loss): (%s) Loss diff: %.6f%s\n", $2, red, loss1[1], loss2[1], diff_loss, nc;
}'
rm $OUT1 $OUT2
done
done
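The `paste`/`awk` step above lines up the `Iteration` lines of the two runs and checks that the losses are identical. A small Python version of the same check is sketched below. It mirrors the awk field positions (iteration in field 2, loss in field 4 with any trailing comma removed), but the sample line format is an assumption, not the verbatim output of `test_acc.py`; `parse_losses` and `compare_losses` are hypothetical names.

```python
from typing import Iterable, List, Tuple


def parse_losses(lines: Iterable[str]) -> List[Tuple[str, float]]:
    """Extract (iteration, loss) pairs from lines containing 'Iteration'."""
    records = []
    for line in lines:
        if "Iteration" not in line:
            continue
        fields = line.split()
        if len(fields) < 4:
            continue
        iteration = fields[1].rstrip(":")
        loss = float(fields[3].split(",")[0])  # like awk's split($4, loss1, ",")
        records.append((iteration, loss))
    return records


def compare_losses(run_a: List[Tuple[str, float]],
                   run_b: List[Tuple[str, float]]) -> List[Tuple[str, bool, float]]:
    """Pair iterations in order and report (iteration, exact match, loss difference)."""
    # zip stops at the shorter run; unlike paste, surplus lines are ignored.
    return [(it, loss_a == loss_b, loss_a - loss_b)
            for (it, loss_a), (_, loss_b) in zip(run_a, run_b)]


if __name__ == "__main__":
    # Made-up log lines (assumed format) for a non-overlapped and an overlapped run.
    non_overlap = parse_losses(["Iteration 0: loss 10.8125, x", "Iteration 1: loss 10.5, x"])
    overlap = parse_losses(["Iteration 0: loss 10.8125, x", "Iteration 1: loss 10.25, x"])
    for it, match, diff in compare_losses(non_overlap, overlap):
        print(f"Iteration {it}: {'match' if match else f'mismatch, loss diff {diff:.6f}'}")
```

Exact equality is kept on purpose, matching the awk script: the overlapped and non-overlapped schedules are expected to produce bit-identical losses.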
================================================
FILE: test/mezo_sgd/hf_opt/test_scheduler_acc_train.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM" "sequence_classification" "question_answering")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
for task_id in "${task_ids[@]}"
do
echo "Testing model_name: $model_name, task_id: $task_id"
if [ "$task_id" == "causalLM" ]; then
lr=1e-4
else
lr=1e-7
fi
CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --overlap no --max_steps 30"
CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --max_steps 30"
OUT1="/tmp/output1_${model_name}_${task_id}.txt"
OUT2="/tmp/output2_${model_name}_${task_id}.txt"
$CMD1 2>&1 | tee $OUT1
$CMD2 2>&1 | tee $OUT2
echo "Comparing outputs..."
echo -e "Model: $model_name, Task: $task_id"
paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
split($4, loss1, ",");
split($7, proj1, ",");
split($11, loss2, ",");
split($14, proj2, ",");
diff_loss = loss1[1] - loss2[1];
diff_proj = proj1[1] - proj2[1];
if (loss1[1] == loss2[1] && proj1[1] == proj2[1])
printf "Iteration %d: %s✓ loss and projected grad match.%s\n", $2, green, nc;
else
printf "Iteration %d: %s✗ Mismatch! Non-Overlap (loss, grad): (%s, %s), Overlap (loss, grad): (%s, %s)\n \tLoss diff: %.6f, Proj grad diff: %.6f%s\n", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;
}'
rm $OUT1 $OUT2
done
done
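The training variant additionally compares the projected gradient, which is expected to match exactly between the non-overlapped and overlapped ZO2 runs. In MeZO-style SGD the projected gradient is the two-point (SPSA) estimate (L(θ + εz) - L(θ - εz)) / (2ε), and the random direction z is regenerated from a seed rather than stored, so identical seeds and losses give identical updates. The toy sketch below illustrates this on a plain-Python quadratic. It is not ZO2's implementation (which perturbs model parameters in place and omits nothing like weight decay here by choice of the sketch, not the library), and `loss`, `perturb`, and `mezo_step` are hypothetical names.

```python
import random
from typing import List, Tuple


def loss(theta: List[float]) -> float:
    """Toy objective: sum of squares."""
    return sum(t * t for t in theta)


def perturb(theta: List[float], seed: int, scale: float) -> List[float]:
    """Return theta + scale * z, where z ~ N(0, I) is regenerated from `seed` on every call."""
    rng = random.Random(seed)
    return [t + scale * rng.gauss(0.0, 1.0) for t in theta]


def mezo_step(theta: List[float], seed: int, eps: float, lr: float) -> Tuple[List[float], float]:
    """One SPSA-style zeroth-order SGD step (no weight decay); returns (new params, projected grad)."""
    loss_plus = loss(perturb(theta, seed, eps))
    loss_minus = loss(perturb(theta, seed, -eps))
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    # Update along the same direction z: theta <- theta - lr * projected_grad * z
    new_theta = perturb(theta, seed, -lr * projected_grad)
    return new_theta, projected_grad


if __name__ == "__main__":
    theta = [0.5, -1.0, 2.0]
    run_a = mezo_step(theta, seed=42, eps=1e-3, lr=1e-2)
    run_b = mezo_step(theta, seed=42, eps=1e-3, lr=1e-2)
    print("identical across runs:", run_a == run_b)
```

For this quadratic the central difference is exact up to rounding: the projected gradient equals 2·(θ·z), the true gradient 2θ projected onto z.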
================================================
FILE: test/mezo_sgd/hf_opt/test_speed.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import sys
sys.path.append("../zo2")
import torch
from tqdm import tqdm
from zo2.config.mezo_sgd import MeZOSGDConfig
from zo2.model.huggingface.opt.mezo_sgd import zo, zo2
from zo2.utils.utils import seed_everything
from utils import (
OPTConfigs,
prepare_data_for_causalLM,
prepare_data_for_sequence_classification,
prepare_data_for_question_answering,
model_size,
get_args,
check_throughput
)
def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForCausalLM(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_causalLM(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForCausalLM(model_config)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):
input_ids, labels = prepare_data_for_sequence_classification(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForSequenceClassification(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)
def train_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, use_tqdm=True)
def train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_train()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, use_tqdm=True)
def eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, use_tqdm=True)
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
input_ids, start_positions, end_positions = prepare_data_for_question_answering(
model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
torch.set_default_dtype(args.model_dtype)
model = zo2.OPTForQuestionAnswering(model_config).to(device)
model.zo_init(zo_config)
total_parameters = model_size(model)["total"]
print(f"model size: {total_parameters/1024**3:.2f} B")
torch.set_default_dtype(original_dtype)
for i in tqdm(range(args.max_steps)):
model.zo_eval()
check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, start_positions=start_positions, end_positions=end_positions, use_tqdm=True)
def test_mezo_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
working_device=args.working_device)
zo_cfg.zo2 = False
train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo2_sgd_causalLM_training():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
offloading_device=args.offloading_device, working_device=args.working_device)
zo_cfg.zo2 = True
train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_causalLM_eval():
seed_everything(args.seed)
model_configs = OPTConfigs()
model_config = getattr(model_configs, args.model_name)
model_config.tie_word_embeddings=False
model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           working_device=args.working_device)
    zo_cfg.zo2 = False
    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_causalLM_eval():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)


def test_mezo_sgd_sequence_classification_training():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           working_device=args.working_device)
    zo_cfg.zo2 = False
    train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_sequence_classification_training():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)


def test_mezo_sgd_sequence_classification_eval():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           working_device=args.working_device)
    zo_cfg.zo2 = False
    eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_sequence_classification_eval():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)


def test_mezo_sgd_question_answering_training():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           working_device=args.working_device)
    zo_cfg.zo2 = False
    train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_question_answering_training():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)


def test_mezo_sgd_question_answering_eval():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           working_device=args.working_device)
    zo_cfg.zo2 = False
    eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_question_answering_eval():
    seed_everything(args.seed)
    model_configs = OPTConfigs()
    model_config = getattr(model_configs, args.model_name)
    model_config.tie_word_embeddings = False
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
                           offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)


if __name__ == "__main__":
    args = get_args()
    original_dtype = torch.get_default_dtype()
    if args.zo_method == "zo":
        if args.task == "causalLM":
            if args.eval:
                test_mezo_sgd_causalLM_eval()
            else:
                test_mezo_sgd_causalLM_training()
        elif args.task == "sequence_classification":
            if args.eval:
                test_mezo_sgd_sequence_classification_eval()
            else:
                test_mezo_sgd_sequence_classification_training()
        elif args.task == "question_answering":
            if args.eval:
                test_mezo_sgd_question_answering_eval()
            else:
                test_mezo_sgd_question_answering_training()
        else:
            raise NotImplementedError(f"Task {args.task} is unsupported.")
    elif args.zo_method == "zo2":
        if args.task == "causalLM":
            if args.eval:
                test_mezo2_sgd_causalLM_eval()
            else:
                test_mezo2_sgd_causalLM_training()
        elif args.task == "sequence_classification":
            if args.eval:
                test_mezo2_sgd_sequence_classification_eval()
            else:
                test_mezo2_sgd_sequence_classification_training()
        elif args.task == "question_answering":
            if args.eval:
                test_mezo2_sgd_question_answering_eval()
            else:
                test_mezo2_sgd_question_answering_training()
        else:
            raise NotImplementedError(f"Task {args.task} is unsupported.")
    else:
        raise NotImplementedError(f"ZO method {args.zo_method} is unsupported.")
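
The driver functions above only build a `MeZOSGDConfig` from `lr`, `weight_decay` and `zo_eps` and pass it to the ZO or ZO2 model. The pure-Python sketch below illustrates, on a plain list of floats, the generic MeZO-SGD (SPSA) update that those three hyperparameters control. It is a hypothetical illustration of the published MeZO recipe, not zo2's implementation. The helper names `perturb_`, `mezo_sgd_step` and `run_demo` are ours.

The step works as follows:

1. Perturb the parameters by a seeded Gaussian direction `z`.
2. Evaluate the loss at `theta + eps*z` and at `theta - eps*z`.
3. Restore `theta` by regenerating `z` from the same seed instead of storing it.
4. Step along `z`, scaled by the scalar projected gradient.

That scalar presumably corresponds to what the accuracy tests print as `projected grad`.

```python
import random
from typing import Callable, List, Tuple


def perturb_(params: List[float], seed: int, scale: float) -> None:
    """In place: params += scale * z, with z ~ N(0, I) regenerated from `seed`."""
    rng = random.Random(seed)  # a fresh RNG per call reproduces the same z
    for i in range(len(params)):
        params[i] += scale * rng.gauss(0.0, 1.0)


def mezo_sgd_step(params: List[float], loss_fn: Callable[[List[float]], float],
                  lr: float, eps: float, seed: int, weight_decay: float = 0.0) -> float:
    """One SPSA/MeZO-style step; returns the scalar projected gradient."""
    perturb_(params, seed, +eps)        # theta + eps*z
    loss_pos = loss_fn(params)
    perturb_(params, seed, -2.0 * eps)  # theta - eps*z
    loss_neg = loss_fn(params)
    perturb_(params, seed, +eps)        # back to theta; z itself is never stored
    projected_grad = (loss_pos - loss_neg) / (2.0 * eps)
    rng = random.Random(seed)           # regenerate the same z for the update
    for i in range(len(params)):
        z = rng.gauss(0.0, 1.0)
        params[i] -= lr * (projected_grad * z + weight_decay * params[i])
    return projected_grad


def run_demo(steps: int = 200) -> Tuple[float, float]:
    """Minimize sum(x^2) from a fixed start; returns (initial_loss, final_loss)."""
    def loss_fn(p: List[float]) -> float:
        return sum(x * x for x in p)

    params = [1.0, -2.0, 0.5, 3.0, -1.5]
    initial = loss_fn(params)
    seeds = random.Random(0)  # one fresh seed per step, as MeZO does
    for _ in range(steps):
        mezo_sgd_step(params, loss_fn, lr=1e-2, eps=1e-3, seed=seeds.randrange(2**31))
    return initial, loss_fn(params)
```

Only forward passes (two loss evaluations) and one integer seed per step are needed, which is why these methods avoid storing activations and gradients.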
================================================
FILE: test/mezo_sgd/hf_opt/test_speed_eval.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"
        CMD1="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval"
        CMD2="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval"
        OUT1="/tmp/output1_${model_name}_${task_id}.txt"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"
        $CMD1 2>&1 | tee $OUT1
        $CMD2 2>&1 | tee $OUT2
        echo "Analyzing throughput..."
        # Count the total number of lines and determine the number of iteration lines
        total_lines1=$(wc -l < $OUT1)
        total_lines2=$(wc -l < $OUT2)
        iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)
        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)
        # Calculate the starting line for the last 50% of iterations
        # (assumes the iteration lines are the final lines of each log)
        start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))
        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))
        # Calculate average tokens per second for the last 50% of the iterations
        avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
        # Multiply before dividing so bc's scale=2 truncation does not discard precision
        ratio=$(echo "scale=2; $avg_tok_s2 * 100 / $avg_tok_s1" | bc)
        echo -e "Model: $model_name, Task: $task_id"
        echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}"
        echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
        echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}"
        rm $OUT1 $OUT2
    done
done
================================================
FILE: test/mezo_sgd/hf_opt/test_speed_train.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"
        CMD1="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30"
        CMD2="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30"
        OUT1="/tmp/output1_${model_name}_${task_id}.txt"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"
        $CMD1 2>&1 | tee $OUT1
        $CMD2 2>&1 | tee $OUT2
        echo "Analyzing throughput..."
        # Count the total number of lines and determine the number of iteration lines
        total_lines1=$(wc -l < $OUT1)
        total_lines2=$(wc -l < $OUT2)
        iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)
        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)
        # Calculate the starting line for the last 50% of iterations
        # (assumes the iteration lines are the final lines of each log)
        start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))
        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))
        # Calculate average tokens per second for the last 50% of the iterations
        avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
        # Multiply before dividing so bc's scale=2 truncation does not discard precision
        ratio=$(echo "scale=2; $avg_tok_s2 * 100 / $avg_tok_s1" | bc)
        echo -e "Model: $model_name, Task: $task_id"
        echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}"
        echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
        echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}"
        rm $OUT1 $OUT2
    done
done
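
The two speed scripts average the `tok/s` field over the last half of the iteration lines. They also report ZO2 throughput as a percentage of ZO throughput.

Below is a Python sketch of the same computation, using hypothetical helper names. It assumes the exact message format printed by `check_throughput()` in `test/mezo_sgd/hf_opt/utils.py`. One difference from the shell version: it selects among the matching lines directly, rather than assuming they are the final lines of the log.

```python
import re
from statistics import mean
from typing import Iterable, Optional

# Line format printed by check_throughput(), e.g.
# "Time cost after iteration 7: 812.34 ms, 2521.10 tok/s"
ITER_RE = re.compile(r"Time cost after iteration \d+: [\d.]+ ms, ([\d.]+) tok/s")


def avg_throughput_last_half(lines: Iterable[str]) -> Optional[float]:
    """Mean tok/s over the last n - n//2 iteration lines (n = number of matches)."""
    values = []
    for line in lines:
        m = ITER_RE.search(line)
        if m:
            values.append(float(m.group(1)))
    if not values:
        return None  # no iteration lines found
    # Same window as the shell arithmetic: skip the first n//2 (warm-up) iterations.
    return mean(values[len(values) // 2:])


def throughput_ratio_percent(zo_lines: Iterable[str],
                             zo2_lines: Iterable[str]) -> Optional[float]:
    """ZO2 throughput as a percentage of ZO throughput."""
    zo = avg_throughput_last_half(zo_lines)
    zo2 = avg_throughput_last_half(zo2_lines)
    if zo is None or zo2 is None or zo == 0:
        return None
    return zo2 / zo * 100.0
```

Discarding the first half of the iterations keeps warm-up effects out of the average.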
================================================
FILE: test/mezo_sgd/hf_opt/utils.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import torch
import time
import argparse
from tqdm import tqdm
import psutil
import os
from transformers import OPTConfig
import pynvml


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--zo_method", type=str, default="zo2")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--task", type=str, default="causalLM")
    parser.add_argument("--model_name", type=str, default="opt_125m")
    parser.add_argument("--model_dtype", type=str, default="fp16")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--max_steps", type=int, default=3)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-1)
    parser.add_argument("--zo_eps", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--sequence_length", type=int, default=2048)
    parser.add_argument("--overlap", type=str, default="all")
    parser.add_argument("--offloading_device", type=str, default="cpu")
    parser.add_argument("--working_device", type=str, default="cuda:0")
    args = parser.parse_args()
    # Map the dtype name (e.g. "fp16") to a torch.dtype via dtype_lookup (defined below).
    args.model_dtype = dtype_lookup[args.model_dtype]
    return args


class OPTConfigs:
    opt_125m: OPTConfig = OPTConfig(num_hidden_layers=12, num_attention_heads=12, hidden_size=768, ffn_dim=3072, max_position_embeddings=2048)
    opt_350m: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=16, hidden_size=1024, ffn_dim=4096, max_position_embeddings=2048)
    opt_1_3b: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=32, hidden_size=2048, ffn_dim=8192, max_position_embeddings=2048)
    opt_2_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=2560, ffn_dim=10240, max_position_embeddings=2048)
    opt_6_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=4096, ffn_dim=16384, max_position_embeddings=2048)
    opt_13b: OPTConfig = OPTConfig(num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, ffn_dim=20480, max_position_embeddings=2048)
    opt_30b: OPTConfig = OPTConfig(num_hidden_layers=48, num_attention_heads=56, hidden_size=7168, ffn_dim=28672, max_position_embeddings=2048)
    opt_66b: OPTConfig = OPTConfig(num_hidden_layers=64, num_attention_heads=72, hidden_size=9216, ffn_dim=36864, max_position_embeddings=2048)
    opt_175b: OPTConfig = OPTConfig(num_hidden_layers=96, num_attention_heads=96, hidden_size=12288, ffn_dim=49152, max_position_embeddings=2048)


dtype_lookup = {
    "fp64": torch.float64,
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16
}


def model_size(model: torch.nn.Module):
    total_size = sum(p.numel() for p in model.parameters())
    trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"total": total_size, "trainable": trainable_size}


def prepare_data_for_causalLM(V, B, T, device='cuda'):
    # Random token ids of shape (B, T); for causal LM the labels are the inputs themselves.
    data_batch = torch.randint(0, V, (B, T)).to(device)
    input_ids = data_batch
    labels = data_batch
    return input_ids, labels


def prepare_data_for_sequence_classification(V, B, T, device='cuda'):
    input_ids = torch.randint(0, V, (B, T)).to(device)
    # torch.randint's upper bound is exclusive, so these dummy labels are all 0.
    labels = torch.randint(0, 1, (B, )).to(device)
    return input_ids, labels


def prepare_data_for_question_answering(V, B, T, device='cuda'):
    input_ids = torch.randint(0, V, (B, T)).to(device)
    # Dummy answer spans: every sample starts at position 0 and ends at position 1.
    start_positions = torch.randint(0, 1, (B, )).to(device)
    end_positions = torch.randint(1, 2, (B, )).to(device)
    return input_ids, start_positions, end_positions


# GPU Memory Monitoring
pynvml.nvmlInit()


def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):
    # NVML reports the memory currently in use on the whole device (all processes,
    # including the CUDA context). This is a point-in-time reading rather than a true
    # peak; the memory scripts take the maximum over all iterations to approximate it.
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)  # Adjust index if necessary
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    peak_memory = info.used / 1024**2  # bytes -> MB
    if use_tqdm:
        tqdm.write("Peak GPU Memory after iteration {}: {:.2f} MB".format(iter+1, peak_memory))
    else:
        print(f"Peak GPU Memory after iteration {iter+1}: {peak_memory:.2f} MB")


# CPU Memory Monitoring
peak_memory_cpu = 0


def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):
    # Track the largest resident set size (RSS) of this process seen so far.
    global peak_memory_cpu
    process = psutil.Process(os.getpid())
    current_memory = process.memory_info().rss / (1024 ** 2)  # Convert to MB
    if current_memory > peak_memory_cpu:
        peak_memory_cpu = current_memory
    if use_tqdm:
        tqdm.write(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB")
    else:
        print(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB")


def reset_peak_cpu_memory_usage():
    global peak_memory_cpu
    peak_memory_cpu = 0
    # Also reset PyTorch's CUDA peak-memory counters when a GPU is available.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs):
    # Time one call of fn(*args, **kwargs) and report tokens processed per second.
    t1 = time.time()
    out = fn(*args, **kwargs)
    t2 = time.time()
    time_cost = t2 - t1
    throughput = total_token_batch_size_per_iter / time_cost
    if use_tqdm:
        tqdm.write("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughput))
    else:
        print("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughput))
    return out
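
The `OPTConfigs` entries above differ only in depth, width and FFN size. As a back-of-the-envelope cross-check on the numbers reported by `model_size()`, the hypothetical helper `estimate_opt_params` counts parameters directly from those fields.

It assumes, to our understanding, the Hugging Face OPT layout:

- pre-LayerNorm decoder layers with biases on every linear layer;
- a learned positional embedding with 2 extra offset rows;
- one final LayerNorm;
- the `OPTConfig` default `vocab_size` of 50272;
- `word_embed_proj_dim == hidden_size`, so there is no `project_in`/`project_out`.

Treat `model_size()` on an instantiated model as the authoritative count.

```python
def estimate_opt_params(num_hidden_layers: int, hidden_size: int, ffn_dim: int,
                        vocab_size: int = 50272, max_position_embeddings: int = 2048,
                        tie_word_embeddings: bool = True) -> int:
    """Rough parameter count for an OPT-style decoder (assumed HF layout, see lead-in)."""
    h = hidden_size
    attn = 4 * (h * h + h)                 # q, k, v and out projections, with biases
    ffn = 2 * h * ffn_dim + ffn_dim + h    # fc1 and fc2, with biases
    norms = 2 * (2 * h)                    # two LayerNorms per layer (weight + bias)
    per_layer = attn + ffn + norms
    # Token embeddings plus learned positions (assumed offset of 2 extra rows).
    embed = vocab_size * h + (max_position_embeddings + 2) * h
    final_norm = 2 * h
    total = num_hidden_layers * per_layer + embed + final_norm
    if not tie_word_embeddings:
        total += vocab_size * h            # separate lm_head (no bias)
    return total
```

With tied embeddings, the `opt_125m` shape (12 layers, hidden 768, FFN 3072) gives 125,239,296 parameters. The test drivers set `tie_word_embeddings=False`, which adds a separate `vocab_size x hidden_size` head, for 163,848,192 in total. They also override `max_position_embeddings` with `--sequence_length`, which changes the positional-embedding term.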
================================================
FILE: test/mezo_sgd/hf_qwen3/record_zo2_memory.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"
        CMD2="python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo2 --max_steps 30"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"
        $CMD2 2>&1 | tee $OUT2
        echo "Recording Peak GPU and CPU Memory usage..."
        max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)
        max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)
        if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
            echo "Could not find memory usage data in the output."
        else
            echo -e "Model: $model_name, Task: $task_id"
            echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
            echo -e "ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}"
        fi
        rm $OUT2
    done
done
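
`record_zo2_memory.sh` takes the 7th whitespace-separated field of every `Peak GPU Memory` / `Peak CPU Memory` line and keeps the numeric maximum. Below is a Python equivalent with a hypothetical helper name. It assumes the message format printed by the monitoring helpers in `test/mezo_sgd/hf_opt/utils.py`. The `hf_qwen3/utils.py` file is not shown here, so we assume it prints the same messages. Like the script's empty-check branch, it returns `None` when no matching line exists.

```python
import re
from typing import Iterable, Optional


def peak_memory_mb(lines: Iterable[str], kind: str = "GPU") -> Optional[float]:
    """Largest 'Peak <kind> Memory after iteration N: X MB' value, or None if absent."""
    pattern = re.compile(rf"Peak {re.escape(kind)} Memory after iteration \d+: ([\d.]+) MB")
    values = [float(m.group(1)) for m in map(pattern.search, lines) if m]
    return max(values) if values else None
```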
================================================
FILE: test/mezo_sgd/hf_qwen3/record_zo2_speed.sh
================================================
#!/bin/bash
set -e
set -o pipefail
model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b")
task_ids=("causalLM")
# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'
for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"
        CMD2="python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo2 --max_steps 30"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"
        $CMD2 2>&1 | tee $OUT2
        echo "Recording throughput..."
        # Count the total number of lines and determine the number of iteration lines
        total_lines2=$(wc -l < $OUT2)
        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)
        # Calculate the starting line for the last 50% of iterations
        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))
        # Calculate average tokens per second for the last 50% of the iterations
        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
        echo -e "Model: $model_name, Task: $task_id"
        echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
        rm $OUT2
    done
done
================================================
FILE: test/mezo_sgd/hf_qwen3/test_acc.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0
import sys
sys.path.append("../zo2")
import torch
from tqdm import tqdm
from zo2.config.mezo_sgd import MeZOSGDConfig
from zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2
from zo2.utils.utils import seed_everything
from utils import (
    Qwen3Configs,
    prepare_data_for_causalLM,
    model_size,
    get_args
)


def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
    input_ids, labels = prepare_data_for_causalLM(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    model = zo.Qwen3ForCausalLM(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model.zo_train()
        loss = model(input_ids=input_ids, labels=labels)
        res = "Iteration {}, loss: {}, projected grad: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad))


def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
input_ids, labels = prepar
SYMBOL INDEX (574 symbols across 44 files)
FILE: example/mezo_runner/metrics.py
function normalize_answer (line 10) | def normalize_answer(s):
function calculate_metric (line 29) | def calculate_metric(predictions, metric_name):
function f1 (line 62) | def f1(pred, gold):
FILE: example/mezo_runner/run.py
class OurArguments (line 40) | class OurArguments(TrainingArguments):
function parse_args (line 125) | def parse_args():
function set_seed (line 133) | def set_seed(seed: int):
class Framework (line 140) | class Framework:
method __init__ (line 142) | def __init__(self, args, task):
method load_model (line 148) | def load_model(self):
method forward (line 272) | def forward(self, input_ids, option_len=None, generation=False):
method one_step_pred (line 305) | def one_step_pred(self, train_samples, eval_sample, verbose=False):
method evaluate (line 382) | def evaluate(self, train_samples, eval_samples, one_train_set_per_eval...
method train (line 404) | def train(self, train_samples, eval_samples):
function result_file_tag (line 524) | def result_file_tag(args):
function main (line 538) | def main():
FILE: example/mezo_runner/tasks.py
function get_task (line 26) | def get_task(task_name):
class Sample (line 39) | class Sample:
class Dataset (line 46) | class Dataset:
method __init__ (line 51) | def __init__(self, subtask=None, **kwargs) -> None:
method get_task_name (line 54) | def get_task_name(self):
method load_dataset (line 57) | def load_dataset():
method get_template (line 60) | def get_template(self, template_version=0):
method build_sample (line 64) | def build_sample(self, example):
method sample_train_sets (line 67) | def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None,...
method sample_subset (line 98) | def sample_subset(self, data_split="train", seed=0, num=100, exclude=N...
method valid_samples (line 110) | def valid_samples(self):
class SST2Dataset (line 114) | class SST2Dataset(Dataset):
method __init__ (line 116) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 119) | def load_dataset(self, path, **kwargs):
method build_sample (line 130) | def build_sample(self, example):
method get_template (line 134) | def get_template(self, template_version=0):
class CopaDataset (line 138) | class CopaDataset(Dataset):
method __init__ (line 142) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 145) | def load_dataset(self, path, **kwargs):
method build_sample (line 154) | def build_sample(self, example):
method get_template (line 165) | def get_template(self, template_version=0):
class BoolQDataset (line 169) | class BoolQDataset(Dataset):
method __init__ (line 170) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 173) | def load_dataset(self, path, **kwargs):
method build_sample (line 182) | def build_sample(self, example):
method get_template (line 192) | def get_template(self, template_version=2):
class MultiRCDataset (line 196) | class MultiRCDataset(Dataset):
method __init__ (line 198) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 201) | def load_dataset(self, path, **kwargs):
method build_sample (line 210) | def build_sample(self, example):
method get_template (line 220) | def get_template(self, template_version=0):
class CBDataset (line 224) | class CBDataset(Dataset):
method __init__ (line 226) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 229) | def load_dataset(self, path, **kwargs):
method build_sample (line 238) | def build_sample(self, example):
method get_template (line 248) | def get_template(self, template_version=0):
class WICDataset (line 252) | class WICDataset(Dataset):
method __init__ (line 254) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 257) | def load_dataset(self, path, **kwargs):
method build_sample (line 266) | def build_sample(self, example):
method get_template (line 276) | def get_template(self, template_version=0):
class WSCDataset (line 280) | class WSCDataset(Dataset):
method __init__ (line 282) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 285) | def load_dataset(self, path, **kwargs):
method build_sample (line 294) | def build_sample(self, example):
method get_template (line 304) | def get_template(self, template_version=0):
class ReCoRDDataset (line 308) | class ReCoRDDataset(Dataset):
method __init__ (line 310) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 313) | def load_dataset(self, path, **kwargs):
method build_sample (line 322) | def build_sample(self, example):
method get_template (line 332) | def get_template(self, template_version=0):
class RTEDataset (line 336) | class RTEDataset(Dataset):
method __init__ (line 338) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 341) | def load_dataset(self, path, **kwargs):
method build_sample (line 350) | def build_sample(self, example):
method get_template (line 360) | def get_template(self, template_version=0):
class SQuADDataset (line 364) | class SQuADDataset(Dataset):
method __init__ (line 368) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 371) | def load_dataset(self):
method build_sample (line 381) | def build_sample(self, example, idx):
method get_template (line 396) | def get_template(self, template_version=0):
class DROPDataset (line 400) | class DROPDataset(Dataset):
method __init__ (line 404) | def __init__(self, subtask=None, **kwargs) -> None:
method load_dataset (line 407) | def load_dataset(self):
method build_sample (line 417) | def build_sample(self, example, idx):
method get_template (line 431) | def get_template(self, template_version=0):
FILE: example/mezo_runner/templates.py
class Template (line 8) | class Template:
method encode (line 9) | def encode(self, sample):
method verbalize (line 15) | def verbalize(self, sample, candidate):
method encode_sfc (line 21) | def encode_sfc(self, sample):
method verbalize_sfc (line 27) | def verbalize_sfc(self, sample, candidate):
class SST2Template (line 34) | class SST2Template(Template):
method encode (line 36) | def encode(self, sample):
method verbalize (line 40) | def verbalize(self, sample, candidate):
method encode_sfc (line 44) | def encode_sfc(self, sample):
method verbalize_sfc (line 47) | def verbalize_sfc(self, sample, candidate):
class CopaTemplate (line 51) | class CopaTemplate(Template):
method get_conjucture (line 56) | def get_conjucture(self, sample):
method get_prompt (line 65) | def get_prompt(self, sample):
method encode (line 77) | def encode(self, sample):
method capitalize (line 81) | def capitalize(self, c):
method verbalize (line 96) | def verbalize(self, sample, candidate):
method encode_sfc (line 100) | def encode_sfc(self, sample):
method verbalize_sfc (line 104) | def verbalize_sfc(self, sample, candidate):
class BoolQTemplate (line 110) | class BoolQTemplate(Template):
method encode (line 111) | def encode(self, sample):
method verbalize (line 119) | def verbalize(self, sample, candidate):
method encode_sfc (line 127) | def encode_sfc(self, sample):
method verbalize_sfc (line 130) | def verbalize_sfc(self, sample, candidate):
class BoolQTemplateV2 (line 134) | class BoolQTemplateV2(Template):
method encode (line 135) | def encode(self, sample):
method verbalize (line 143) | def verbalize(self, sample, candidate):
method encode_sfc (line 151) | def encode_sfc(self, sample):
method verbalize_sfc (line 154) | def verbalize_sfc(self, sample, candidate):
class BoolQTemplateV3 (line 158) | class BoolQTemplateV3(Template):
method encode (line 159) | def encode(self, sample):
method verbalize (line 167) | def verbalize(self, sample, candidate):
method encode_sfc (line 175) | def encode_sfc(self, sample):
method verbalize_sfc (line 178) | def verbalize_sfc(self, sample, candidate):
class MultiRCTemplate (line 182) | class MultiRCTemplate(Template):
method encode (line 186) | def encode(self, sample):
method verbalize (line 192) | def verbalize(self, sample, candidate):
method encode_sfc (line 198) | def encode_sfc(self, sample):
method verbalize_sfc (line 201) | def verbalize_sfc(self, sample, candidate):
class CBTemplate (line 205) | class CBTemplate(Template):
method encode (line 209) | def encode(self, sample):
method verbalize (line 214) | def verbalize(self, sample, candidate):
method encode_sfc (line 219) | def encode_sfc(self, sample):
method verbalize_sfc (line 222) | def verbalize_sfc(self, sample, candidate):
class WICTemplate (line 226) | class WICTemplate(Template):
method encode (line 230) | def encode(self, sample):
method verbalize (line 236) | def verbalize(self, sample, candidate):
method encode_sfc (line 242) | def encode_sfc(self, sample):
method verbalize_sfc (line 245) | def verbalize_sfc(self, sample, candidate):
class WSCTemplate (line 249) | class WSCTemplate(Template):
method encode (line 253) | def encode(self, sample):
method verbalize (line 259) | def verbalize(self, sample, candidate):
method encode_sfc (line 265) | def encode_sfc(self, sample):
method verbalize_sfc (line 268) | def verbalize_sfc(self, sample, candidate):
class ReCoRDTemplate (line 272) | class ReCoRDTemplate(Template):
method encode (line 275) | def encode(self, sample):
method verbalize (line 280) | def verbalize(self, sample, candidate):
method encode_sfc (line 285) | def encode_sfc(self, sample):
method verbalize_sfc (line 288) | def verbalize_sfc(self, sample, candidate):
class ReCoRDTemplateGPT3 (line 292) | class ReCoRDTemplateGPT3(Template):
method encode (line 295) | def encode(self, sample):
method verbalize (line 299) | def verbalize(self, sample, candidate):
method encode_sfc (line 308) | def encode_sfc(self, sample):
method verbalize_sfc (line 311) | def verbalize_sfc(self, sample, candidate):
class RTETemplate (line 316) | class RTETemplate(Template):
method encode (line 320) | def encode(self, sample):
method verbalize (line 325) | def verbalize(self, sample, candidate):
method encode_sfc (line 330) | def encode_sfc(self, sample):
method verbalize_sfc (line 333) | def verbalize_sfc(self, sample, candidate):
class SQuADv2Template (line 337) | class SQuADv2Template(Template):
method encode (line 339) | def encode(self, sample):
method verbalize (line 347) | def verbalize(self, sample, candidate):
method encode_sfc (line 356) | def encode_sfc(self, sample):
method verbalize_sfc (line 359) | def verbalize_sfc(self, sample, candidate):
class DROPTemplate (line 363) | class DROPTemplate(Template):
method encode (line 365) | def encode(self, sample):
method verbalize (line 373) | def verbalize(self, sample, candidate):
method encode_sfc (line 382) | def encode_sfc(self, sample):
method verbalize_sfc (line 385) | def verbalize_sfc(self, sample, candidate):
FILE: example/mezo_runner/utils.py
function custom_loss_fn_with_option_len (line 37) | def custom_loss_fn_with_option_len(self, input_ids, logits, labels, opti...
function forward_wrap_with_option_len (line 87) | def forward_wrap_with_option_len(self, input_ids=None, labels=None, opti...
function encode_prompt (line 161) | def encode_prompt(task, template, train_samples, eval_sample, tokenizer,...
class ICLCollator (line 233) | class ICLCollator:
method __call__ (line 239) | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
class DataCollatorWithPaddingAndNesting (line 260) | class DataCollatorWithPaddingAndNesting:
method __call__ (line 271) | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
class NondiffCollator (line 290) | class NondiffCollator(DataCollatorMixin):
method torch_call (line 301) | def torch_call(self, features):
class SIGUSR1Callback (line 344) | class SIGUSR1Callback(transformers.TrainerCallback):
method __init__ (line 350) | def __init__(self) -> None:
method handle_signal (line 357) | def handle_signal(self, signum, frame):
method on_step_end (line 361) | def on_step_end(self, args, state, control, **kwargs):
method on_train_end (line 366) | def on_train_end(self, args, state, control, **kwargs):
class Prediction (line 372) | class Prediction:
function count_time (line 378) | def count_time(name):
function temp_seed (line 388) | def temp_seed(seed):
class EnhancedJSONEncoder (line 397) | class EnhancedJSONEncoder(json.JSONEncoder):
method default (line 398) | def default(self, o):
function write_predictions_to_file (line 404) | def write_predictions_to_file(final_preds, output):
function write_metrics_to_file (line 410) | def write_metrics_to_file(metrics, output):
FILE: script/add-copyright.py
function add_license_header (line 9) | def add_license_header(file_path, comment_style):
FILE: test/mezo_sgd/hf_opt/test_acc.py
function train_mezo_sgd_causalLM (line 22) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo2_sgd_causalLM (line 38) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo_sgd_causalLM (line 54) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo2_sgd_causalLM (line 70) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo_sgd_sequence_classification (line 87) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo2_sgd_sequence_classification (line 103) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev...
function eval_mezo_sgd_sequence_classification (line 119) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic...
function eval_mezo2_sgd_sequence_classification (line 135) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo_sgd_question_answering (line 152) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c...
function train_mezo2_sgd_question_answering (line 168) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='...
function eval_mezo_sgd_question_answering (line 184) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu...
function eval_mezo2_sgd_question_answering (line 200) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c...
function test_mezo_sgd_causalLM_training (line 217) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 227) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 238) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 248) | def test_mezo2_sgd_causalLM_eval():
function test_mezo_sgd_sequence_classification_training (line 260) | def test_mezo_sgd_sequence_classification_training():
function test_mezo2_sgd_sequence_classification_training (line 270) | def test_mezo2_sgd_sequence_classification_training():
function test_mezo_sgd_sequence_classification_eval (line 281) | def test_mezo_sgd_sequence_classification_eval():
function test_mezo2_sgd_sequence_classification_eval (line 291) | def test_mezo2_sgd_sequence_classification_eval():
function test_mezo_sgd_question_answering_training (line 303) | def test_mezo_sgd_question_answering_training():
function test_mezo2_sgd_question_answering_training (line 313) | def test_mezo2_sgd_question_answering_training():
function test_mezo_sgd_question_answering_eval (line 324) | def test_mezo_sgd_question_answering_eval():
function test_mezo2_sgd_question_answering_eval (line 334) | def test_mezo2_sgd_question_answering_eval():
FILE: test/mezo_sgd/hf_opt/test_memory.py
function train_mezo_sgd_causalLM (line 25) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function train_mezo2_sgd_causalLM (line 42) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function eval_mezo_sgd_causalLM (line 59) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function eval_mezo2_sgd_causalLM (line 76) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function train_mezo_sgd_sequence_classification (line 94) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo2_sgd_sequence_classification (line 111) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev...
function eval_mezo_sgd_sequence_classification (line 128) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic...
function eval_mezo2_sgd_sequence_classification (line 145) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo_sgd_question_answering (line 163) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c...
function train_mezo2_sgd_question_answering (line 180) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='...
function eval_mezo_sgd_question_answering (line 197) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu...
function eval_mezo2_sgd_question_answering (line 214) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c...
function test_mezo_sgd_causalLM_training (line 232) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 243) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 254) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 265) | def test_mezo2_sgd_causalLM_eval():
function test_mezo_sgd_sequence_classification_training (line 277) | def test_mezo_sgd_sequence_classification_training():
function test_mezo2_sgd_sequence_classification_training (line 288) | def test_mezo2_sgd_sequence_classification_training():
function test_mezo_sgd_sequence_classification_eval (line 299) | def test_mezo_sgd_sequence_classification_eval():
function test_mezo2_sgd_sequence_classification_eval (line 310) | def test_mezo2_sgd_sequence_classification_eval():
function test_mezo_sgd_question_answering_training (line 322) | def test_mezo_sgd_question_answering_training():
function test_mezo2_sgd_question_answering_training (line 333) | def test_mezo2_sgd_question_answering_training():
function test_mezo_sgd_question_answering_eval (line 344) | def test_mezo_sgd_question_answering_eval():
function test_mezo2_sgd_question_answering_eval (line 355) | def test_mezo2_sgd_question_answering_eval():
FILE: test/mezo_sgd/hf_opt/test_speed.py
function train_mezo_sgd_causalLM (line 23) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo2_sgd_causalLM (line 36) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo_sgd_causalLM (line 49) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo2_sgd_causalLM (line 62) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo_sgd_sequence_classification (line 76) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo2_sgd_sequence_classification (line 89) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev...
function eval_mezo_sgd_sequence_classification (line 102) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic...
function eval_mezo2_sgd_sequence_classification (line 115) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi...
function train_mezo_sgd_question_answering (line 129) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c...
function train_mezo2_sgd_question_answering (line 142) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='...
function eval_mezo_sgd_question_answering (line 155) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu...
function eval_mezo2_sgd_question_answering (line 168) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c...
function test_mezo_sgd_causalLM_training (line 182) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 193) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 204) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 215) | def test_mezo2_sgd_causalLM_eval():
function test_mezo_sgd_sequence_classification_training (line 227) | def test_mezo_sgd_sequence_classification_training():
function test_mezo2_sgd_sequence_classification_training (line 238) | def test_mezo2_sgd_sequence_classification_training():
function test_mezo_sgd_sequence_classification_eval (line 249) | def test_mezo_sgd_sequence_classification_eval():
function test_mezo2_sgd_sequence_classification_eval (line 260) | def test_mezo2_sgd_sequence_classification_eval():
function test_mezo_sgd_question_answering_training (line 272) | def test_mezo_sgd_question_answering_training():
function test_mezo2_sgd_question_answering_training (line 283) | def test_mezo2_sgd_question_answering_training():
function test_mezo_sgd_question_answering_eval (line 294) | def test_mezo_sgd_question_answering_eval():
function test_mezo2_sgd_question_answering_eval (line 305) | def test_mezo2_sgd_question_answering_eval():
FILE: test/mezo_sgd/hf_opt/utils.py
function get_args (line 13) | def get_args():
class OPTConfigs (line 36) | class OPTConfigs:
function model_size (line 56) | def model_size(model: torch.nn.Module):
function prepare_data_for_causalLM (line 62) | def prepare_data_for_causalLM(V, B, T, device='cuda'):
function prepare_data_for_sequence_classification (line 68) | def prepare_data_for_sequence_classification(V, B, T, device='cuda'):
function prepare_data_for_question_answering (line 73) | def prepare_data_for_question_answering(V, B, T, device='cuda'):
function check_peak_gpu_memory_usage (line 82) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):
function check_and_update_peak_cpu_memory_usage (line 94) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):
function reset_peak_cpu_memory_usage (line 105) | def reset_peak_cpu_memory_usage():
function check_throughput (line 112) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u...
FILE: test/mezo_sgd/hf_qwen3/test_acc.py
function train_mezo_sgd_causalLM (line 20) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo2_sgd_causalLM (line 36) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo_sgd_causalLM (line 52) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo2_sgd_causalLM (line 68) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function test_mezo_sgd_causalLM_training (line 85) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 97) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 110) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 121) | def test_mezo2_sgd_causalLM_eval():
FILE: test/mezo_sgd/hf_qwen3/test_memory.py
function train_mezo_sgd_causalLM (line 23) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function train_mezo2_sgd_causalLM (line 40) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function eval_mezo_sgd_causalLM (line 57) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function eval_mezo2_sgd_causalLM (line 74) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
function test_mezo_sgd_causalLM_training (line 93) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 104) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 115) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 126) | def test_mezo2_sgd_causalLM_eval():
FILE: test/mezo_sgd/hf_qwen3/test_speed.py
function train_mezo_sgd_causalLM (line 21) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function train_mezo2_sgd_causalLM (line 34) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo_sgd_causalLM (line 47) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
function eval_mezo2_sgd_causalLM (line 60) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
function test_mezo_sgd_causalLM_training (line 74) | def test_mezo_sgd_causalLM_training():
function test_mezo2_sgd_causalLM_training (line 85) | def test_mezo2_sgd_causalLM_training():
function test_mezo_sgd_causalLM_eval (line 96) | def test_mezo_sgd_causalLM_eval():
function test_mezo2_sgd_causalLM_eval (line 107) | def test_mezo2_sgd_causalLM_eval():
FILE: test/mezo_sgd/hf_qwen3/utils.py
function get_args (line 13) | def get_args():
class Qwen3Configs (line 35) | class Qwen3Configs:
function model_size (line 52) | def model_size(model: torch.nn.Module):
function prepare_data_for_causalLM (line 58) | def prepare_data_for_causalLM(V, B, T, device='cuda'):
function check_peak_gpu_memory_usage (line 67) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):
function check_and_update_peak_cpu_memory_usage (line 79) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):
function reset_peak_cpu_memory_usage (line 90) | def reset_peak_cpu_memory_usage():
function check_throughput (line 97) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u...
FILE: test/mezo_sgd/nanogpt/test_acc.py
function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, model_config, device='cuda'):
function train_mezo2_sgd (line 28) | def train_mezo2_sgd(model, args, model_config, device='cuda'):
function eval_mezo_sgd (line 40) | def eval_mezo_sgd(model, args, model_config, device='cuda'):
function eval_mezo2_sgd (line 52) | def eval_mezo2_sgd(model, args, model_config, device='cuda'):
function test_mezo_sgd_training (line 64) | def test_mezo_sgd_training():
function test_mezo2_sgd_training (line 79) | def test_mezo2_sgd_training():
function test_mezo_sgd_eval (line 94) | def test_mezo_sgd_eval():
function test_mezo2_sgd_eval (line 109) | def test_mezo2_sgd_eval():
FILE: test/mezo_sgd/nanogpt/test_memory.py
function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, modelConfig, device='cuda:0'):
function train_mezo2_sgd (line 30) | def train_mezo2_sgd(model, args, modelConfig, device='cuda:0'):
function eval_mezo_sgd (line 44) | def eval_mezo_sgd(model, args, modelConfig, device='cuda:0'):
function eval_mezo2_sgd (line 58) | def eval_mezo2_sgd(model, args, modelConfig, device='cuda:0'):
function test_mezo_sgd_training (line 72) | def test_mezo_sgd_training():
function test_mezo2_sgd_training (line 87) | def test_mezo2_sgd_training():
function test_mezo_sgd_eval (line 102) | def test_mezo_sgd_eval():
function test_mezo2_sgd_eval (line 117) | def test_mezo2_sgd_eval():
FILE: test/mezo_sgd/nanogpt/test_speed.py
function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, modelConfig, device='cuda'):
function train_mezo2_sgd (line 26) | def train_mezo2_sgd(model, args, modelConfig, device='cuda'):
function eval_mezo_sgd (line 36) | def eval_mezo_sgd(model, args, modelConfig, device='cuda'):
function eval_mezo2_sgd (line 46) | def eval_mezo2_sgd(model, args, modelConfig, device='cuda'):
function test_mezo_sgd_training (line 56) | def test_mezo_sgd_training():
function test_mezo2_sgd_training (line 71) | def test_mezo2_sgd_training():
function test_mezo_sgd_eval (line 86) | def test_mezo_sgd_eval():
function test_mezo2_sgd_eval (line 101) | def test_mezo2_sgd_eval():
FILE: test/mezo_sgd/nanogpt/utils.py
function get_args (line 13) | def get_args():
function model_size (line 41) | def model_size(model: torch.nn.Module):
function prepare_data (line 47) | def prepare_data(V, B, T, device='cuda'):
function check_peak_gpu_memory_usage (line 57) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):
function check_and_update_peak_cpu_memory_usage (line 69) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):
function reset_peak_cpu_memory_usage (line 80) | def reset_peak_cpu_memory_usage():
function check_throughput (line 86) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u...
FILE: zo2/config/__init__.py
function ZOConfig (line 7) | def ZOConfig(method: str = "mezo-sgd", **kwargs):
FILE: zo2/config/mezo_sgd.py
class MeZOSGDConfig (line 8) | class MeZOSGDConfig:
FILE: zo2/model/base.py
class BaseZOModel (line 6) | class BaseZOModel(torch.nn.Module):
method __init__ (line 7) | def __init__(self):
method zo_train (line 17) | def zo_train(self):
method zo_eval (line 24) | def zo_eval(self):
method register_zo_train_loss_fn_pre_hook (line 31) | def register_zo_train_loss_fn_pre_hook(self, hook_fn):
method register_zo_train_loss_fn_post_hook (line 34) | def register_zo_train_loss_fn_post_hook(self, hook_fn):
method register_zo_eval_loss_fn_pre_hook (line 37) | def register_zo_eval_loss_fn_pre_hook(self, hook_fn):
method register_zo_eval_loss_fn_post_hook (line 40) | def register_zo_eval_loss_fn_post_hook(self, hook_fn):
method register_custom_opt (line 43) | def register_custom_opt(self, custom_opt_obj):
FILE: zo2/model/huggingface/opt/__init__.py
function get_opt_for_causalLM (line 8) | def get_opt_for_causalLM(zo_config):
function get_opt_for_sequence_classification (line 14) | def get_opt_for_sequence_classification(zo_config):
function get_opt_for_question_answering (line 20) | def get_opt_for_question_answering(zo_config):
FILE: zo2/model/huggingface/opt/mezo_sgd/__init__.py
function get_opt_for_causalLM_mezo_sgd (line 7) | def get_opt_for_causalLM_mezo_sgd(config: MeZOSGDConfig):
function get_opt_for_sequence_classification_mezo_sgd (line 10) | def get_opt_for_sequence_classification_mezo_sgd(config: MeZOSGDConfig):
function get_opt_for_question_answering_mezo_sgd (line 13) | def get_opt_for_question_answering_mezo_sgd(config: MeZOSGDConfig):
FILE: zo2/model/huggingface/opt/mezo_sgd/utils.py
function fn_get_opt_decoder_hidden_states_from_layer_outputs (line 6) | def fn_get_opt_decoder_hidden_states_from_layer_outputs(input):
function get_shift_logits (line 9) | def get_shift_logits(logits):
function get_shift_labels (line 12) | def get_shift_labels(labels):
function get_pooled_logits (line 15) | def get_pooled_logits(logits, batch_size, sequence_lengths):
function get_start_logits_and_end_logits (line 18) | def get_start_logits_and_end_logits(logits):
function get_qa_loss (line 24) | def get_qa_loss(loss_fct, start_logits, start_positions, end_logits, end...
function init_all_hidden_states (line 30) | def init_all_hidden_states(output_hidden_states):
function init_all_self_attns (line 33) | def init_all_self_attns(output_attentions):
function init_next_decoder_cache (line 36) | def init_next_decoder_cache(use_cache):
function update_next_decoder_cache (line 39) | def update_next_decoder_cache(use_cache, next_decoder_cache, layer_outpu...
function update_all_self_attns (line 44) | def update_all_self_attns(output_attentions, all_self_attns, layer_outpu...
function update_all_hidden_states (line 49) | def update_all_hidden_states(output_hidden_states, all_hidden_states, hi...
function get_past_key_value (line 54) | def get_past_key_value(past_key_values, idx):
function get_opt_sequence_classification_pooled_logits (line 57) | def get_opt_sequence_classification_pooled_logits(self, logits, input_id...
function get_opt_sequence_classification_loss (line 71) | def get_opt_sequence_classification_loss(self, loss, pooled_logits, labe...
function get_opt_question_answering_start_end_logits (line 93) | def get_opt_question_answering_start_end_logits(logits):
function get_opt_question_answering_loss (line 99) | def get_opt_question_answering_loss(total_loss, start_logits, start_posi...
FILE: zo2/model/huggingface/opt/mezo_sgd/zo.py
class OPTDecoder (line 42) | class OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel):
method __init__ (line 50) | def __init__(self, config: OPTConfig):
class OPTModel (line 91) | class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel):
method __init__ (line 92) | def __init__(self, config: OPTConfig):
class OPTForCausalLM (line 99) | class OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, Ba...
method __init__ (line 102) | def __init__(self, config: OPTConfig):
method zo_init (line 113) | def zo_init(self, zo_config):
method forward (line 117) | def forward(
class OPTForSequenceClassification (line 221) | class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassifica...
method __init__ (line 222) | def __init__(self, config: OPTConfig):
method zo_init (line 232) | def zo_init(self, zo_config):
method forward (line 243) | def forward(
class OPTForQuestionAnswering (line 274) | class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTP...
method __init__ (line 275) | def __init__(self, config: OPTConfig):
method zo_init (line 284) | def zo_init(self, zo_config):
method forward (line 289) | def forward(
class OptimizerOPTForCausalLM (line 358) | class OptimizerOPTForCausalLM(MeZOSGD):
method inner_zo_forward (line 361) | def inner_zo_forward(
method inner_zo_eval_forward (line 427) | def inner_zo_eval_forward(
class OptimizerOPTForSequenceClassification (line 478) | class OptimizerOPTForSequenceClassification(MeZOSGD):
method inner_zo_forward (line 481) | def inner_zo_forward(
method inner_zo_eval_forward (line 572) | def inner_zo_eval_forward(
class OptimizerOPTForQuestionAnswering (line 620) | class OptimizerOPTForQuestionAnswering(MeZOSGD):
method inner_zo_forward (line 623) | def inner_zo_forward(
method inner_zo_eval_forward (line 696) | def inner_zo_eval_forward(
FILE: zo2/model/huggingface/opt/mezo_sgd/zo2.py
class OPTDecoder (line 42) | class OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel, BaseZOModel):
method __init__ (line 50) | def __init__(self, config: OPTConfig):
method zo_init (line 90) | def zo_init(self, zo_config):
method forward (line 94) | def forward(
class OPTModel (line 120) | class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel, BaseZOModel):
method __init__ (line 121) | def __init__(self, config: OPTConfig):
method zo_init (line 128) | def zo_init(self, zo_config):
method forward (line 140) | def forward(
class OPTForCausalLM (line 162) | class OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, Ba...
method __init__ (line 165) | def __init__(self, config: OPTConfig):
method zo_init (line 174) | def zo_init(self, zo_config):
method forward (line 180) | def forward(
class OPTForSequenceClassification (line 280) | class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassifica...
method __init__ (line 281) | def __init__(self, config: OPTConfig):
method zo_init (line 291) | def zo_init(self, zo_config):
method forward (line 303) | def forward(
class OPTForQuestionAnswering (line 335) | class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTP...
method __init__ (line 336) | def __init__(self, config: OPTConfig):
method zo_init (line 345) | def zo_init(self, zo_config):
method forward (line 351) | def forward(
class OptimizerOPTDecoder (line 421) | class OptimizerOPTDecoder(MeZO2SGD):
method init_zo2 (line 423) | def init_zo2(self):
method init_zo2_upload (line 434) | def init_zo2_upload(self):
method inner_zo_forward (line 457) | def inner_zo_forward(
method inner_zo_eval_forward (line 655) | def inner_zo_eval_forward(
class OptimizerOPTModel (line 919) | class OptimizerOPTModel(MeZO2SGD):
method init_zo2 (line 921) | def init_zo2(self):
method init_zo2_upload (line 932) | def init_zo2_upload(self):
method inner_zo_forward (line 936) | def inner_zo_forward(
method inner_zo_eval_forward (line 974) | def inner_zo_eval_forward(
class OptimizerOPTForCausalLM (line 1020) | class OptimizerOPTForCausalLM(MeZO2SGD):
method init_zo2_upload (line 1022) | def init_zo2_upload(self):
method inner_zo_forward (line 1026) | def inner_zo_forward(
method inner_zo_eval_forward (line 1114) | def inner_zo_eval_forward(
class OptimizerOPTForSequenceClassification (line 1216) | class OptimizerOPTForSequenceClassification(MeZO2SGD):
method init_zo2_upload (line 1218) | def init_zo2_upload(self):
method inner_zo_forward (line 1222) | def inner_zo_forward(
method inner_zo_eval_forward (line 1347) | def inner_zo_eval_forward(
class OptimizerOPTForQuestionAnswering (line 1441) | class OptimizerOPTForQuestionAnswering(MeZO2SGD):
method init_zo2_upload (line 1443) | def init_zo2_upload(self):
method inner_zo_forward (line 1447) | def inner_zo_forward(
method inner_zo_eval_forward (line 1540) | def inner_zo_eval_forward(
FILE: zo2/model/huggingface/qwen3/__init__.py
function get_qwen3_for_causalLM (line 8) | def get_qwen3_for_causalLM(zo_config):
FILE: zo2/model/huggingface/qwen3/mezo_sgd/__init__.py
function get_qwen3_for_causalLM_mezo_sgd (line 7) | def get_qwen3_for_causalLM_mezo_sgd(config: MeZOSGDConfig):
FILE: zo2/model/huggingface/qwen3/mezo_sgd/utils.py
function fn_get_qwen3_decoder_hidden_states_from_layer_outputs (line 6) | def fn_get_qwen3_decoder_hidden_states_from_layer_outputs(input):
function fn_get_qwen3_sliced_logits_from_hidden_states (line 9) | def fn_get_qwen3_sliced_logits_from_hidden_states(hidden_states, slice_i...
FILE: zo2/model/huggingface/qwen3/mezo_sgd/zo.py
class Qwen3Model (line 38) | class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel):
method __init__ (line 46) | def __init__(self, config: Qwen3Config):
class Qwen3ForCausalLM (line 64) | class Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedM...
method __init__ (line 69) | def __init__(self, config: Qwen3Config):
method zo_init (line 79) | def zo_init(self, zo_config):
method forward (line 86) | def forward(
class OptimizerQwen3ForCausalLM (line 148) | class OptimizerQwen3ForCausalLM(MeZOSGD):
method inner_zo_forward (line 151) | def inner_zo_forward(
method inner_zo_eval_forward (line 213) | def inner_zo_eval_forward(
FILE: zo2/model/huggingface/qwen3/mezo_sgd/zo2.py
class Qwen3Model (line 41) | class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel, BaseZO...
method __init__ (line 48) | def __init__(self, config: Qwen3Config):
method zo_init (line 68) | def zo_init(self, zo_config):
method forward (line 74) | def forward(
class Qwen3ForCausalLM (line 99) | class Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedM...
method __init__ (line 104) | def __init__(self, config):
method zo_init (line 114) | def zo_init(self, zo_config):
method forward (line 123) | def forward(
class OptimizerQwen3Model (line 183) | class OptimizerQwen3Model(MeZO2SGD):
method init_zo2 (line 185) | def init_zo2(self):
method init_zo2_upload (line 196) | def init_zo2_upload(self):
method inner_zo_forward (line 214) | def inner_zo_forward(
method inner_zo_eval_forward (line 365) | def inner_zo_eval_forward(
class OptimizerQwen3ForCausalLM (line 509) | class OptimizerQwen3ForCausalLM(MeZO2SGD):
method init_zo2_upload (line 511) | def init_zo2_upload(self):
method inner_zo_forward (line 515) | def inner_zo_forward(
method inner_zo_eval_forward (line 594) | def inner_zo_eval_forward(
FILE: zo2/model/huggingface/zo_init.py
function zo_hf_init (line 25) | def zo_hf_init(zo_config):
function main (line 37) | def main():
FILE: zo2/model/nanogpt/__init__.py
function get_nanogpt (line 8) | def get_nanogpt(zo_config):
FILE: zo2/model/nanogpt/mezo_sgd/__init__.py
function get_nanogpt_mezo_sgd (line 9) | def get_nanogpt_mezo_sgd(config: MeZOSGDConfig):
FILE: zo2/model/nanogpt/mezo_sgd/zo.py
class GPT (line 13) | class GPT(model.GPT, BaseZOModel):
method __init__ (line 14) | def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig):
method forward (line 18) | def forward(self, idx, pos, targets=None):
class Optimizer (line 26) | class Optimizer(MeZOSGD):
method inner_zo_forward (line 29) | def inner_zo_forward(self, idx, pos, targets):
method inner_zo_eval_forward (line 44) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):
FILE: zo2/model/nanogpt/mezo_sgd/zo2.py
class GPT (line 14) | class GPT(model.GPT, BaseZOModel):
method __init__ (line 15) | def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig):
method forward (line 19) | def forward(self, idx, pos, targets=None):
class Optimizer (line 27) | class Optimizer(MeZO2SGD):
method init_zo2_upload (line 29) | def init_zo2_upload(self):
method inner_zo_forward (line 50) | def inner_zo_forward(self, idx, pos, targets):
method inner_zo_eval_forward (line 114) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):
FILE: zo2/model/nanogpt/model.py
class GPTConfig (line 19) | class GPTConfig:
class GPTConfigs (line 28) | class GPTConfigs:
class LayerNorm (line 43) | class LayerNorm(nn.Module):
method __init__ (line 46) | def __init__(self, ndim, bias):
method forward (line 51) | def forward(self, input):
class CausalSelfAttention (line 54) | class CausalSelfAttention(nn.Module):
method __init__ (line 56) | def __init__(self, config):
method forward (line 77) | def forward(self, x):
class MLP (line 103) | class MLP(nn.Module):
method __init__ (line 105) | def __init__(self, config):
method forward (line 112) | def forward(self, x):
class Block (line 119) | class Block(nn.Module):
method __init__ (line 121) | def __init__(self, config):
method forward (line 128) | def forward(self, x):
class GPT (line 134) | class GPT(nn.Module):
method __init__ (line 136) | def __init__(self, config):
method get_num_params (line 166) | def get_num_params(self, non_embedding=True):
method _init_weights (line 178) | def _init_weights(self, module):
method forward (line 186) | def forward(self, idx, pos, targets=None):
method crop_block_size (line 205) | def crop_block_size(self, block_size):
method from_pretrained (line 217) | def from_pretrained(cls, model_type, override_args=None):
method configure_optimizers (line 273) | def configure_optimizers(self, weight_decay, learning_rate, betas, dev...
method estimate_mfu (line 299) | def estimate_mfu(self, fwdbwd_per_iter, dt):
method generate (line 315) | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
FILE: zo2/optimizer/base.py
class BaseOptimizer (line 7) | class BaseOptimizer(Optimizer):
method __init__ (line 12) | def __init__(self, params, defaults):
method _update_lr (line 25) | def _update_lr(self):
method _set_lr (line 28) | def _set_lr(self):
FILE: zo2/optimizer/mezo_sgd/utils/comm.py
function module_to_bucket_inplace (line 9) | def module_to_bucket_inplace(module: nn.Module):
function bucket_to_module_inplace (line 13) | def bucket_to_module_inplace(bucket: torch.Tensor, module: nn.Module):
function create_disk_offload_path (line 23) | def create_disk_offload_path(path, module_id):
function get_disk_offload_path (line 35) | def get_disk_offload_path(path, module_id):
function clear_disk_offload_path (line 38) | def clear_disk_offload_path(path, module_id):
function set_nested_attr (line 46) | def set_nested_attr(obj, attr, value):
FILE: zo2/optimizer/mezo_sgd/zo.py
class MeZOSGD (line 16) | class MeZOSGD(BaseOptimizer):
method __init__ (line 21) | def __init__(self, model: nn.Module, config: MeZOSGDConfig):
method zo_perturb_parameters (line 47) | def zo_perturb_parameters(self, module: nn.Module, scaling_factor: flo...
method zo_update (line 65) | def zo_update(self, module, weight_decay=None):
method zo_perturb_shifts (line 90) | def zo_perturb_shifts(self, first_perturb_shift=1, stride=2):
method compute_grad (line 99) | def compute_grad(self, loss1, loss2):
method zo_forward (line 103) | def zo_forward(self, *args, zo_random_seed: int=None, **kwargs):
method zo_eval_forward (line 129) | def zo_eval_forward(self, *args, **kwargs):
method inner_zo_forward (line 139) | def inner_zo_forward(self, idx, pos, targets):
method inner_zo_eval_forward (line 159) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):
FILE: zo2/optimizer/mezo_sgd/zo2.py
class MeZO2SGD (line 17) | class MeZO2SGD(MeZOSGD):
method __init__ (line 26) | def __init__(self, model, config: MeZOSGDConfig):
method init_zo2 (line 52) | def init_zo2(self):
method init_zo2_amp (line 68) | def init_zo2_amp(self):
method assign_zo2_attributes (line 84) | def assign_zo2_attributes(self, source, target):
method zo_update (line 100) | def zo_update(self, module, weight_decay=None):
method module_dual_forward (line 116) | def module_dual_forward(self, module, inputs1, inputs2, projected_grad...
method function_dual_forward (line 146) | def function_dual_forward(self, fn, inputs1, inputs2):
method zo_forward (line 164) | def zo_forward(self, *args, seed: int=None, **kwargs):
method task_upload (line 188) | def task_upload(self, module, device='cuda', upload_sync=False, *args,...
method task_offload (line 212) | def task_offload(self, module, device='cpu', offload_sync=False, *args...
method task_compute_module (line 237) | def task_compute_module(self, module, inputs1, inputs2, grad, compute_...
method task_compute_function (line 297) | def task_compute_function(self, fn, inputs1, inputs2, compute_sync=Fal...
method zo_eval_forward (line 355) | def zo_eval_forward(self, *args, **kwargs):
method add_zo2_eval_comm_hooks (line 371) | def add_zo2_eval_comm_hooks(self, blocks):
method clear_zo2_eval_comm_hooks (line 391) | def clear_zo2_eval_comm_hooks(self, handles):
method eval_upload_hook (line 401) | def eval_upload_hook(self, module, input):
method eval_offload_hook (line 416) | def eval_offload_hook(self, module, input, output):
method upload_impl (line 442) | def upload_impl(
method offload_impl (line 484) | def offload_impl(
method compute_module_impl (line 525) | def compute_module_impl(
method compute_function_impl (line 551) | def compute_function_impl(
method amp_decompress_impl (line 577) | def amp_decompress_impl(self, module: nn.Module) -> nn.Module:
method amp_compress_impl (line 597) | def amp_compress_impl(self, module: nn.Module) -> nn.Module:
method init_zo2_upload (line 618) | def init_zo2_upload(self):
method inner_zo_forward (line 646) | def inner_zo_forward(self, idx, pos, targets):
method inner_zo_eval_forward (line 721) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):
FILE: zo2/trainer/hf_transformers/trainer.py
function _is_peft_model (line 259) | def _is_peft_model(model):
function _get_fsdp_ckpt_kwargs (line 271) | def _get_fsdp_ckpt_kwargs():
function safe_globals (line 279) | def safe_globals():
class ZOTrainer (line 317) | class ZOTrainer(Trainer):
method __init__ (line 322) | def __init__(
method _inner_training_loop (line 359) | def _inner_training_loop(
method _load_optimizer_and_scheduler (line 875) | def _load_optimizer_and_scheduler(self, checkpoint, model=None):
method create_optimizer_and_scheduler (line 885) | def create_optimizer_and_scheduler(self, num_training_steps: int, mode...
method _move_model_to_device (line 903) | def _move_model_to_device(self, model, device):
method _zo2_unsupported_conditions (line 908) | def _zo2_unsupported_conditions(self, args):
method register_zo2_training_step_pre_hook (line 920) | def register_zo2_training_step_pre_hook(self, hook_fn):
method register_zo2_training_step_post_hook (line 931) | def register_zo2_training_step_post_hook(self, hook_fn):
method zo2_training_step (line 952) | def zo2_training_step(self, model: nn.Module, inputs: dict[str, Union[...
FILE: zo2/trainer/hf_trl/sft_trainer.py
class ZOSFTTrainer (line 226) | class ZOSFTTrainer(SFTTrainer):
method __init__ (line 228) | def __init__(
method _inner_training_loop (line 278) | def _inner_training_loop(self, batch_size=None, args=None, resume_from...
method train (line 713) | def train(self, *args, **kwargs):
method _load_optimizer_and_scheduler (line 738) | def _load_optimizer_and_scheduler(self, checkpoint, model=None):
method create_optimizer_and_scheduler (line 748) | def create_optimizer_and_scheduler(self, num_training_steps: int, mode...
method _move_model_to_device (line 766) | def _move_model_to_device(self, model, device):
method _zo2_unsupported_conditions (line 771) | def _zo2_unsupported_conditions(self, args):
method register_zo2_training_step_pre_hook (line 787) | def register_zo2_training_step_pre_hook(self, hook_fn):
method register_zo2_training_step_post_hook (line 790) | def register_zo2_training_step_post_hook(self, hook_fn):
method zo2_training_step (line 793) | def zo2_training_step(self, model: nn.Module, inputs: Dict[str, Union[...
FILE: zo2/utils/utils.py
function print_all (line 10) | def print_all(module: nn.Module, inputs, outputs):
function print_hook (line 29) | def print_hook(module, input, output):
function print_para_and_device (line 33) | def print_para_and_device(model):
function cal_self_reg_loss (line 37) | def cal_self_reg_loss(logits, labels):
function seed_everything (line 44) | def seed_everything(seed):
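The `MeZOSGD` methods indexed above (`zo_perturb_parameters`, `zo_perturb_shifts`, `compute_grad`, `zo_update`) point to the usual MeZO recipe. The parameters are perturbed in place along a seeded Gaussian direction `z` and evaluated at θ+εz and θ−εz. The scalar projected gradient `(loss1 - loss2) / (2ε)` is formed, and `z` is regenerated from the same seed to apply the update. The following is a framework-free sketch of that recipe, assuming the standard MeZO formulation. It works on a plain list of floats rather than `nn.Module` parameters. The helper names `perturb` and `mezo_step` are hypothetical and are not ZO2's API. The +1, −2, +1 shift pattern is also an assumption, loosely suggested by the `zo_perturb_shifts(first_perturb_shift=1, stride=2)` signature. ZO2's real implementation may differ, for example in how it schedules perturbations across offloaded blocks.

```python
import random

def perturb(params, seed, scale, eps):
    """Hypothetical helper: add scale * eps * z to each parameter in place.

    z ~ N(0, 1) is regenerated from `seed`, so the same direction can be
    replayed later without ever storing it (the memory trick behind MeZO).
    """
    rng = random.Random(seed)
    for i in range(len(params)):
        params[i] += scale * eps * rng.gauss(0.0, 1.0)

def mezo_step(params, loss_fn, lr, eps, seed, weight_decay=0.0):
    """Hypothetical single MeZO-SGD step on a list of floats (sketch only)."""
    perturb(params, seed, +1.0, eps)   # theta + eps * z
    loss1 = loss_fn(params)
    perturb(params, seed, -2.0, eps)   # theta - eps * z
    loss2 = loss_fn(params)
    perturb(params, seed, +1.0, eps)   # back to theta (up to float rounding)

    # Two-point (SPSA-style) estimate of the directional derivative along z.
    projected_grad = (loss1 - loss2) / (2.0 * eps)

    # Replay the same z from the seed and take an SGD step along it.
    rng = random.Random(seed)
    for i in range(len(params)):
        z = rng.gauss(0.0, 1.0)
        params[i] -= lr * (projected_grad * z + weight_decay * params[i])
    return loss1
```

Used on a toy quadratic, a fresh seed per step (for example the step index) drives the parameters toward the minimum using forward evaluations only:

```python
target = [3.0, -1.0]
loss_fn = lambda p: sum((a - b) ** 2 for a, b in zip(p, target))
params = [0.0, 0.0]
for step in range(500):
    mezo_step(params, loss_fn, lr=0.01, eps=1e-3, seed=step)
```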
Condensed preview: 101 files, each showing path, character count, and a content snippet.
[
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 7765,
"preview": "# ZO2 (Zeroth-Order Offloading): Full Parameter Fine-Tuning 175B LLMs with 18GB GPU Memory\n\n["
},
{
"path": "example/mezo_runner/README.md",
"chars": 1546,
"preview": "# Example: Apply MeZO on LLMs\n\nModified from [MeZO](https://github.com/princeton-nlp/MeZO/blob/main/large_models/README."
},
{
"path": "example/mezo_runner/metrics.py",
"chars": 3304,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport numpy as np\nimport collections"
},
{
"path": "example/mezo_runner/mezo.sh",
"chars": 1915,
"preview": "MODEL=${MODEL:-facebook/opt-1.3b}\nMODEL_NAME=(${MODEL//\\// })\nMODEL_NAME=\"${MODEL_NAME[-1]}\"\n\nBS=${BS:-16}\nLR=${LR:-1e-5"
},
{
"path": "example/mezo_runner/run.py",
"chars": 27996,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/"
},
{
"path": "example/mezo_runner/tasks.py",
"chars": 14373,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nCopied https://github.com/princet"
},
{
"path": "example/mezo_runner/templates.py",
"chars": 13027,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nCopied https://github.com/princet"
},
{
"path": "example/mezo_runner/utils.py",
"chars": 16287,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/"
},
{
"path": "requirements.txt",
"chars": 831,
"preview": "brotli==1.0.9\ncertifi==2024.7.4\ncharset-normalizer==3.3.2\nfilelock==3.13.1\nidna==3.7\nJinja2==3.1.4\nMarkupSafe==2.1.3\nnum"
},
{
"path": "script/add-copyright.py",
"chars": 1451,
"preview": "import os\nimport datetime\nimport logging\n\ncurrent_year = datetime.datetime.now().year\nowner = \"liangyuwang\"\nlogging.basi"
},
{
"path": "script/clear-pycache.sh",
"chars": 63,
"preview": "find . | grep -E \"(/__pycache__$|\\.pyc$|\\.pyo$)\" | xargs rm -rf"
},
{
"path": "setup.py",
"chars": 893,
"preview": "from setuptools import setup, find_packages\n\nwith open('requirements.txt') as f:\n requirements = f.read().splitlines("
},
{
"path": "test/README.md",
"chars": 473,
"preview": "# Test\n\n- Important Notice: For fine-tuning the **OPT-175B** model, ensure that your system is equipped with at least `1"
},
{
"path": "test/mezo_sgd/hf_gpt/trainer.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "test/mezo_sgd/hf_llama/trainer.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "test/mezo_sgd/hf_opt/record_zo2_memory.sh",
"chars": 1263,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/record_zo2_speed.sh",
"chars": 1416,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_acc.py",
"chars": 17646,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_opt/test_acc_eval.sh",
"chars": 1673,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_acc_train.sh",
"chars": 1866,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_memory.py",
"chars": 18969,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_opt/test_memory_eval.sh",
"chars": 1637,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_memory_train.sh",
"chars": 1623,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_scheduler_acc_eval.sh",
"chars": 1687,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_scheduler_acc_train.sh",
"chars": 1923,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_speed.py",
"chars": 17620,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_opt/test_speed_eval.sh",
"chars": 2225,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/test_speed_train.sh",
"chars": 2211,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_"
},
{
"path": "test/mezo_sgd/hf_opt/utils.py",
"chars": 5487,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport time\nimport argpa"
},
{
"path": "test/mezo_sgd/hf_qwen3/record_zo2_memory.sh",
"chars": 1225,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/record_zo2_speed.sh",
"chars": 1378,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_acc.py",
"chars": 6186,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_acc_eval.sh",
"chars": 1574,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_acc_train.sh",
"chars": 1748,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_memory.py",
"chars": 6304,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_memory_train.sh",
"chars": 1571,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_speed.py",
"chars": 5793,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/hf_qwen3/test_speed_train.sh",
"chars": 2159,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b"
},
{
"path": "test/mezo_sgd/hf_qwen3/utils.py",
"chars": 5023,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport time\nimport argpa"
},
{
"path": "test/mezo_sgd/nanogpt/record_zo2_memory.sh",
"chars": 1090,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/record_zo2_speed.sh",
"chars": 1235,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_acc.py",
"chars": 5566,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/nanogpt/test_acc_eval.sh",
"chars": 1244,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_acc_train.sh",
"chars": 1455,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_memory.py",
"chars": 5984,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/nanogpt/test_memory_eval.sh",
"chars": 1410,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_memory_train.sh",
"chars": 1396,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_speed.py",
"chars": 5428,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n"
},
{
"path": "test/mezo_sgd/nanogpt/test_speed_eval.sh",
"chars": 1974,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/test_speed_train.sh",
"chars": 1960,
"preview": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1"
},
{
"path": "test/mezo_sgd/nanogpt/utils.py",
"chars": 3540,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport time\nimport argpa"
},
{
"path": "tutorial/README.md",
"chars": 556,
"preview": "# API of ZO2\n\nWelcome to the ZO2 API documentation!\n\n## Standard Usage\n\n### 1. Quick Start\n\nFor a straightforward introd"
},
{
"path": "tutorial/colab.ipynb",
"chars": 7577,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Environment Setting\"\n ]\n },\n "
},
{
"path": "tutorial/demo.ipynb",
"chars": 10203,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Fine-tune HF Model with Your Cus"
},
{
"path": "tutorial/huggingface.ipynb",
"chars": 7095,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Environment Setting\"\n ]\n },\n "
},
{
"path": "tutorial/nanogpt.ipynb",
"chars": 31560,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"## Environment Setting\"\n ]\n },\n "
},
{
"path": "zo2/README.md",
"chars": 556,
"preview": "# Core code of ZO2\n\n## Features\n\n1. Fuse model dual-forward and optimizer step into model forward code. For example,\n\n``"
},
{
"path": "zo2/__init__.py",
"chars": 379,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n# configs\nfrom .config import ZOConfi"
},
{
"path": "zo2/config/__init__.py",
"chars": 412,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .mezo_sgd import MeZOSGDConfig\n\n"
},
{
"path": "zo2/config/mezo_sgd.py",
"chars": 1683,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nfrom dataclasses import "
},
{
"path": "zo2/model/__init__.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/model/base.py",
"chars": 1463,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\nclass BaseZOModel(torch"
},
{
"path": "zo2/model/huggingface/__init__.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/model/huggingface/gpt/mezo_sgd/zo.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/model/huggingface/gpt/mezo_sgd/zo2.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/model/huggingface/llama/mezo_sgd/zo.py",
"chars": 237,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfr"
},
{
"path": "zo2/model/huggingface/llama/mezo_sgd/zo2.py",
"chars": 237,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfr"
},
{
"path": "zo2/model/huggingface/opt/__init__.py",
"chars": 762,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n mezo_sgd,\n)\n\ndef "
},
{
"path": "zo2/model/huggingface/opt/mezo_sgd/__init__.py",
"chars": 601,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import zo, zo2\nfrom .....confi"
},
{
"path": "zo2/model/huggingface/opt/mezo_sgd/utils.py",
"chars": 4746,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\ndef fn_get_opt_decoder_"
},
{
"path": "zo2/model/huggingface/opt/mezo_sgd/zo.py",
"chars": 35426,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfr"
},
{
"path": "zo2/model/huggingface/opt/mezo_sgd/zo2.py",
"chars": 81270,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport random\nimport torch\nimport tor"
},
{
"path": "zo2/model/huggingface/qwen3/__init__.py",
"chars": 318,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n mezo_sgd,\n)\n\ndef "
},
{
"path": "zo2/model/huggingface/qwen3/mezo_sgd/__init__.py",
"chars": 284,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import zo, zo2\nfrom .....confi"
},
{
"path": "zo2/model/huggingface/qwen3/mezo_sgd/utils.py",
"chars": 310,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\ndef fn_get_qwen3_decode"
},
{
"path": "zo2/model/huggingface/qwen3/mezo_sgd/zo.py",
"chars": 11400,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfr"
},
{
"path": "zo2/model/huggingface/qwen3/mezo_sgd/zo2.py",
"chars": 31473,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport random\nimport torch\nimport tor"
},
{
"path": "zo2/model/huggingface/zo_init.py",
"chars": 1316,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom contextlib import contextmanager"
},
{
"path": "zo2/model/nanogpt/__init__.py",
"chars": 296,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n mezo_sgd,\n)\n\ndef "
},
{
"path": "zo2/model/nanogpt/mezo_sgd/__init__.py",
"chars": 352,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom ..model import GPTConfig, GPTCon"
},
{
"path": "zo2/model/nanogpt/mezo_sgd/zo.py",
"chars": 1451,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn.function"
},
{
"path": "zo2/model/nanogpt/mezo_sgd/zo2.py",
"chars": 5977,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn.function"
},
{
"path": "zo2/model/nanogpt/model.py",
"chars": 16449,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/"
},
{
"path": "zo2/optimizer/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "zo2/optimizer/base.py",
"chars": 1025,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nfrom torch.optim.optimiz"
},
{
"path": "zo2/optimizer/mezo_sgd/__init__.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/optimizer/mezo_sgd/utils/__init__.py",
"chars": 121,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .com import *\nfrom .comm import "
},
{
"path": "zo2/optimizer/mezo_sgd/utils/com.py",
"chars": 83,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
},
{
"path": "zo2/optimizer/mezo_sgd/utils/comm.py",
"chars": 1547,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport os\nimport torch\nimport torch.n"
},
{
"path": "zo2/optimizer/mezo_sgd/zo.py",
"chars": 6585,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append('./zo2')\n\n"
},
{
"path": "zo2/optimizer/mezo_sgd/zo2.py",
"chars": 33373,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append('./zo2')\ni"
},
{
"path": "zo2/trainer/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "zo2/trainer/hf_transformers/__init__.py",
"chars": 30,
"preview": "from .trainer import ZOTrainer"
},
{
"path": "zo2/trainer/hf_transformers/trainer.py",
"chars": 42539,
"preview": "# Copyright 2020-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n"
},
{
"path": "zo2/trainer/hf_trl/__init__.py",
"chars": 37,
"preview": "from .sft_trainer import ZOSFTTrainer"
},
{
"path": "zo2/trainer/hf_trl/sft_trainer.py",
"chars": 35661,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "zo2/utils/__init__.py",
"chars": 117,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .utils import seed_everything"
},
{
"path": "zo2/utils/utils.py",
"chars": 1622,
"preview": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom torch import nn\nimport torch\nimp"
}
]
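Each element of the listing above is an object with `path`, `chars`, and `preview` fields. `chars` is the full file size, while `preview` is only a truncated snippet. The page does not say where the full JSON lives. Assuming it has been saved locally and read into a string, a short Python sketch like the following could total the character counts per top-level directory. The function name `chars_by_top_dir` is hypothetical.

```python
import json
from collections import defaultdict

def chars_by_top_dir(listing_json: str) -> dict:
    """Sum the `chars` field of each entry, grouped by top-level directory.

    Files at the repository root (no '/' in their path) are grouped
    under the key "(root)".
    """
    totals = defaultdict(int)
    for entry in json.loads(listing_json):
        path = entry["path"]
        top = path.split("/", 1)[0] if "/" in path else "(root)"
        totals[top] += entry["chars"]
    return dict(totals)
```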
About this extraction
This page contains the full source code of the liangyuwang/zo2 GitHub repository, extracted and formatted as plain text. The extraction includes 101 files (619.2 KB), approximately 153.3k tokens, and a symbol index with 574 extracted functions, classes, methods, constants, and types.