Repository: liangyuwang/zo2 Branch: main Commit: 4bca25c2cd69 Files: 101 Total size: 619.2 KB Directory structure: gitextract_z35rb3tr/ ├── LICENSE ├── README.md ├── env.yml ├── example/ │ ├── demo/ │ │ └── train_zo2_with_hf_opt.py │ └── mezo_runner/ │ ├── README.md │ ├── metrics.py │ ├── mezo.sh │ ├── run.py │ ├── tasks.py │ ├── templates.py │ └── utils.py ├── requirements.txt ├── script/ │ ├── add-copyright.py │ └── clear-pycache.sh ├── setup.py ├── test/ │ ├── README.md │ └── mezo_sgd/ │ ├── hf_gpt/ │ │ └── trainer.py │ ├── hf_llama/ │ │ └── trainer.py │ ├── hf_opt/ │ │ ├── record_zo2_memory.sh │ │ ├── record_zo2_speed.sh │ │ ├── test_acc.py │ │ ├── test_acc_eval.sh │ │ ├── test_acc_train.sh │ │ ├── test_memory.py │ │ ├── test_memory_eval.sh │ │ ├── test_memory_train.sh │ │ ├── test_scheduler_acc_eval.sh │ │ ├── test_scheduler_acc_train.sh │ │ ├── test_speed.py │ │ ├── test_speed_eval.sh │ │ ├── test_speed_train.sh │ │ └── utils.py │ ├── hf_qwen3/ │ │ ├── record_zo2_memory.sh │ │ ├── record_zo2_speed.sh │ │ ├── test_acc.py │ │ ├── test_acc_eval.sh │ │ ├── test_acc_train.sh │ │ ├── test_memory.py │ │ ├── test_memory_train.sh │ │ ├── test_speed.py │ │ ├── test_speed_train.sh │ │ └── utils.py │ └── nanogpt/ │ ├── record_zo2_memory.sh │ ├── record_zo2_speed.sh │ ├── test_acc.py │ ├── test_acc_eval.sh │ ├── test_acc_train.sh │ ├── test_memory.py │ ├── test_memory_eval.sh │ ├── test_memory_train.sh │ ├── test_speed.py │ ├── test_speed_eval.sh │ ├── test_speed_train.sh │ └── utils.py ├── tutorial/ │ ├── README.md │ ├── colab.ipynb │ ├── demo.ipynb │ ├── huggingface.ipynb │ └── nanogpt.ipynb └── zo2/ ├── README.md ├── __init__.py ├── config/ │ ├── __init__.py │ └── mezo_sgd.py ├── model/ │ ├── __init__.py │ ├── base.py │ ├── huggingface/ │ │ ├── __init__.py │ │ ├── gpt/ │ │ │ └── mezo_sgd/ │ │ │ ├── zo.py │ │ │ └── zo2.py │ │ ├── llama/ │ │ │ └── mezo_sgd/ │ │ │ ├── zo.py │ │ │ └── zo2.py │ │ ├── opt/ │ │ │ ├── __init__.py │ │ │ └── mezo_sgd/ │ │ │ ├── __init__.py │ │ │ 
├── utils.py │ │ │ ├── zo.py │ │ │ └── zo2.py │ │ ├── qwen3/ │ │ │ ├── __init__.py │ │ │ └── mezo_sgd/ │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── zo.py │ │ │ └── zo2.py │ │ └── zo_init.py │ └── nanogpt/ │ ├── __init__.py │ ├── mezo_sgd/ │ │ ├── __init__.py │ │ ├── zo.py │ │ └── zo2.py │ └── model.py ├── optimizer/ │ ├── __init__.py │ ├── base.py │ └── mezo_sgd/ │ ├── __init__.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── com.py │ │ └── comm.py │ ├── zo.py │ └── zo2.py ├── trainer/ │ ├── __init__.py │ ├── hf_transformers/ │ │ ├── __init__.py │ │ └── trainer.py │ └── hf_trl/ │ ├── __init__.py │ └── sft_trainer.py └── utils/ ├── __init__.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # ZO2 (Zeroth-Order Offloading): Full Parameter Fine-Tuning 175B LLMs with 18GB GPU Memory [![arXiv](https://img.shields.io/badge/Arxiv-2503.12668-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2503.12668) [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/liangyuwang/zo2/blob/main/LICENSE) [![GitDiagram](https://img.shields.io/badge/Git-Diagram%20-blue)](https://gitdiagram.com/liangyuwang/zo2) [![DeepWiki](https://img.shields.io/badge/Devin-DeepWiki%20-green)](https://deepwiki.com/liangyuwang/zo2) 👋 Welcome! **ZO2** is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using **zeroth-order (ZO)** optimization techniques and advanced **offloading** technologies. This framework is particularly tailored for setups with limited GPU memory (e.g., fine-tuning **[OPT-175B](https://arxiv.org/abs/2205.01068)** with just **18GB GPU memory**), enabling the fine-tuning of models that were previously unmanageable due to hardware constraints. 
- The table below displays the GPU memory usage for various OPT model sizes when fine-tuned using the ZO2 framework: | OPT Models | 1.3B | 2.7B | 6.7B | 13B | 30B | 66B | 175B | | :-----------------------: | :------: | :------: | :------: | :------: | :------: | :-------: | :-----------------: | | **GPU memory (GB)** | `3.75` | `4.14` | `4.99` | `6.18` | `8.86` | `12.07` | **`18.04`** | - [Install](#️installation) the package and execute the following test to see the memory usage: ```shell bash test/mezo_sgd/hf_opt/record_zo2_memory.sh ``` ## 📰 News - 16/07/2025: ZO2 is accepted by [COLM](https://colmweb.org/index.html). - 02/05/2025: Added support for [Qwen3](https://qwenlm.github.io/blog/qwen3/). You can now fully fine-tune the [32B version](https://huggingface.co/Qwen/Qwen3-32B-FP8) with just 6GB GPU memory using ZO2. Please refer to our [example](example/mezo_runner/). - 01/05/2025: We upgraded the environment and dependencies to align with the latest `transformers==4.51.3`. - 06/03/2025: We have open-sourced ZO2! ## 💡 Key Features - **Optimized ZO CPU Offloading**: ZO2 leverages `zeroth-order (ZO)` methods to efficiently use `CPU offloading`, avoiding redundant data transfers and significantly reducing GPU memory demands. This allows for handling large-scale models on hardware with limited GPU resources. - **Dynamic Scheduling**: Incorporates a high-performance scheduler to optimize the `computation-communication overlap`, enhancing GPU utilization and preventing training delays. - **Capability for Very Large Models**: Enables the fine-tuning of extraordinarily large models, such as those with over `175 billion parameters`, on single GPUs with as little as `18GB` of memory, previously impossible with traditional methods. - **Empirical Validation**: ZO2 has demonstrated through rigorous testing that it can efficiently fine-tune massive models `without extra time costs or accuracy losses`, confirming its effectiveness for large-scale model training. 
## ⚙️ Installation We offer two installation options, and you only need to use one of them to install ZO2: 1. To experiment with our examples, tutorials, or tests, follow these steps to set up the ZO2 environment: ```shell git clone https://github.com/liangyuwang/zo2.git cd zo2/ conda env create -f env.yml conda activate zo2 ``` 2. If you want to use ZO2 as a package in your own code, you can install it directly in your Python environment. Before installing the ZO2 package, ensure you have the required dependencies: - [PyTorch](https://pytorch.org/get-started/locally/) >= 2.4.0, CUDA >= 12.1 Once the dependencies are installed, you can install the ZO2 package using pip: ```shell pip install git+https://github.com/liangyuwang/zo2.git ``` ## 🛠️ Usage We utilize the [OPT](https://arxiv.org/abs/2205.01068) models and [MeZO-SGD](https://arxiv.org/abs/2305.17333) as examples. For additional information, please refer to the section on [Supported Models and ZO methods](#-supported-models-zo-methods-and-tasks-support). ### 1. Using [MeZO-Runner](example/mezo_runner/) to Evaluate Fine-tuning Tasks Before running the following commands, please ensure that you have cloned the entire project. If you [installed](#️installation) ZO2 using option 2, you will need to run "git clone https://github.com/liangyuwang/zo2.git" to obtain the complete project, then navigate to the zo2 folder by "cd zo2". ```shell cd example/mezo_runner/ export CUDA_VISIBLE_DEVICES=0 MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh ``` ### 2. 
Fine-Tuning HF Models with ZOTrainer / ZOSFTTrainer [[Trainer](./tutorial/huggingface.ipynb)] ```python from zo2 import ZOConfig, zo_hf_init from zo2.trainer.hf_transformers import ZOTrainer from transformers import TrainingArguments # Model and optimizer init zo_config = ZOConfig(method="mezo-sgd", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5) with zo_hf_init(zo_config):     from transformers import OPTForCausalLM     model = OPTForCausalLM.from_pretrained("facebook/opt-125m")     model.zo_init(zo_config) training_args = TrainingArguments("test-trainer") trainer = ZOTrainer( model, args = training_args, train_dataset=..., # get training dataset eval_dataset=..., # get eval dataset data_collator=..., # get data_collator tokenizer=..., # use suitable tokenizer ... ) trainer.train() ``` ### 3. Train HF Models with Custom Training Loop [[demo](./tutorial/demo.ipynb)] ```python from zo2 import ZOConfig, zo_hf_init # Model and optimizer init zo_config = ZOConfig(method="mezo-sgd", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5) with zo_hf_init(zo_config):     from transformers import OPTForCausalLM     model = OPTForCausalLM.from_pretrained("facebook/opt-125m")     model.zo_init(zo_config) # Training loop for i in range(max_training_step):     # Train     training_input_ids, training_labels = ...   # get training data batch     model.zo_train()     loss = model(input_ids=training_input_ids, labels=training_labels)     # Evaluate     eval_input_ids, eval_labels = ...   # get eval data batch     model.zo_eval()         output = model(input_ids=eval_input_ids, labels=eval_labels) ``` ## ✨ Tutorial Please refer to [tutorial](./tutorial/). 
## 🤖 Supported Models, ZO methods, and Tasks - **Models**: * [NanoGPT](https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py) (mainly for idea evaluation) * [Transformers](https://github.com/huggingface/transformers): [OPT](https://arxiv.org/abs/2205.01068) - **ZO methods**: * [MeZO-SGD](https://arxiv.org/abs/2305.17333) - **Tasks**: Please refer to [MeZO-Runner](example/mezo_runner/) ## 🧪 Test Please refer to [test](./test/). ## 🧭 Roadmap - [ ] Support more models like LLaMA and DeepSeek - [ ] Support more ZO methods - [ ] Support more offloading strategies (Disk offloading) ## 🚶 Contributing Feel free to submit issues and pull requests to improve the project! ## 📲 Contact * Liangyu Wang: liangyu.wang@kaust.edu.sa ## 📖 BibTeX ``` @article{wang2025zo2, title={ZO2: Scalable Zeroth-Order Fine-Tuning for Extremely Large Language Models with Limited GPU Memory}, author={Wang, Liangyu and Ren, Jie and Xu, Hang and Wang, Junxiao and Xie, Huanyi and Keyes, David E and Wang, Di}, journal={arXiv preprint arXiv:2503.12668}, year={2025} } ``` ================================================ FILE: env.yml ================================================ name: zo2 channels: - pytorch - nvidia - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - blas=1.0=mkl - brotli-python=1.0.9=py311h6a678d5_8 - bzip2=1.0.8=h5eee18b_6 - ca-certificates=2024.7.2=h06a4308_0 - certifi=2024.7.4=py311h06a4308_0 - charset-normalizer=3.3.2=pyhd3eb1b0_0 - cuda-cudart=12.1.105=0 - cuda-cupti=12.1.105=0 - cuda-libraries=12.1.0=0 - cuda-nvrtc=12.1.105=0 - cuda-nvtx=12.1.105=0 - cuda-opencl=12.6.37=0 - cuda-runtime=12.1.0=0 - cuda-version=12.6=3 - ffmpeg=4.3=hf484d3e_0 - filelock=3.13.1=py311h06a4308_0 - freetype=2.12.1=h4a9f257_0 - gmp=6.2.1=h295c915_3 - gmpy2=2.1.2=py311hc9b5ff0_0 - gnutls=3.6.15=he1e5248_0 - idna=3.7=py311h06a4308_0 - intel-openmp=2023.1.0=hdb19cb5_46306 - jinja2=3.1.4=py311h06a4308_0 - jpeg=9e=h5eee18b_3 - lame=3.100=h7b6447c_0 - 
lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libcublas=12.1.0.26=0 - libcufft=11.0.2.4=0 - libcufile=1.11.0.15=0 - libcurand=10.3.7.37=0 - libcusolver=11.4.4.55=0 - libcusparse=12.0.2.55=0 - libdeflate=1.17=h5eee18b_1 - libffi=3.4.4=h6a678d5_1 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h5eee18b_3 - libidn2=2.3.4=h5eee18b_0 - libjpeg-turbo=2.0.0=h9bf148f_0 - libnpp=12.0.2.50=0 - libnvjitlink=12.1.105=0 - libnvjpeg=12.1.1.14=0 - libpng=1.6.39=h5eee18b_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libuuid=1.41.5=h5eee18b_0 - libwebp-base=1.3.2=h5eee18b_0 - llvm-openmp=14.0.6=h9e868ea_0 - lz4-c=1.9.4=h6a678d5_1 - markupsafe=2.1.3=py311h5eee18b_0 - mkl=2023.1.0=h213fc3f_46344 - mkl-service=2.4.0=py311h5eee18b_1 - mkl_fft=1.3.8=py311h5eee18b_0 - mkl_random=1.2.4=py311hdb19cb5_0 - mpc=1.1.0=h10f8cd9_1 - mpfr=4.0.2=hb69a4c5_1 - mpmath=1.3.0=py311h06a4308_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - networkx=3.3=py311h06a4308_0 - numpy=1.26.4=py311h08b1b3b_0 - numpy-base=1.26.4=py311hf175353_0 - openh264=2.1.1=h4ff587b_0 - openjpeg=2.5.2=he7f1fd0_0 - openssl=3.0.14=h5eee18b_0 - pillow=10.4.0=py311h5eee18b_0 - pip=24.0=py311h06a4308_0 - pysocks=1.7.1=py311h06a4308_0 - python=3.11.9=h955ad1f_0 - pytorch=2.4.0=py3.11_cuda12.1_cudnn9.1.0_0 - pytorch-cuda=12.1=ha16c6d3_5 - pytorch-mutex=1.0=cuda - pyyaml=6.0.1=py311h5eee18b_0 - readline=8.2=h5eee18b_0 - requests=2.32.3=py311h06a4308_0 - setuptools=72.1.0=py311h06a4308_0 - sqlite=3.45.3=h5eee18b_0 - sympy=1.12=py311h06a4308_0 - tbb=2021.8.0=hdb19cb5_0 - tk=8.6.14=h39e8969_0 - torchaudio=2.4.0=py311_cu121 - torchtriton=3.0.0=py311 - torchvision=0.19.0=py311_cu121 - typing_extensions=4.11.0=py311h06a4308_0 - urllib3=2.2.2=py311h06a4308_0 - wheel=0.43.0=py311h06a4308_0 - xz=5.4.6=h5eee18b_1 - yaml=0.2.5=h7b6447c_0 - zlib=1.2.13=h5eee18b_1 - zstd=1.5.5=hc292b87_2 - pip: - 
# Command-line hyperparameters for the demo run
parser = argparse.ArgumentParser()

# -- optimisation / training options
parser.add_argument("--zo_method", type=str, default="zo2")
parser.add_argument("--eval", action="store_true")
parser.add_argument("--model_name", type=str, default="facebook/opt-2.7b")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--max_steps", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--weight_decay", type=float, default=1e-1)
parser.add_argument("--zo_eps", type=float, default=1e-3)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--offloading_device", type=str, default="cpu")
parser.add_argument("--working_device", type=str, default="cuda:0")

# -- generation options used by the inference demo at the end of the script
parser.add_argument("--use_cache", action="store_true")
parser.add_argument("--max_new_tokens", type=int, default=50)
parser.add_argument("--temperature", type=float, default=1.0)

args = parser.parse_args()

# Seed before any model or data is created so runs are repeatable
seed_everything(args.seed)

# "zo2" selects the offloading variant; any other value selects the plain ZO path
use_zo2 = args.zo_method == "zo2"

# Zeroth-order optimizer configuration (MeZO-SGD)
zo_config = ZOConfig(
    method="mezo-sgd",
    zo2=use_zo2,
    lr=args.lr,
    weight_decay=args.weight_decay,
    eps=args.zo_eps,
    offloading_device=args.offloading_device,
    working_device=args.working_device,
)

# The model class is imported inside the zo_hf_init context on purpose:
# the import must happen while the context is active.
with zo_hf_init(zo_config):
    from transformers import OPTForCausalLM
    model = OPTForCausalLM.from_pretrained(args.model_name)
    model.zo_init(zo_config)

# Without ZO2 the whole model is moved to the working device by hand
if not use_zo2:
    model = model.to(args.working_device)

print(f"Check if zo2 init correctly: {hasattr(model, 'zo_training')}")
While CPU offloading has been commonly applied in inference to manage memory-intensive tasks, its application in training, especially fine-tuning, remains less explored. Recently, some works have tried to introduce CPU offloading into LLM training. However, they are typically constrained by the capabilities of first-order optimizers such as SGD and Adaptive Moment Estimation (AdamW), and limited GPU memory, restricting large-scale model scalability on single GPU systems. Using first-order optimizers introduces inefficiencies in CPU offloading: Multiple communication operations during the training of LLMs necessitate offloading the same data twice—once for each pass. This redundancy not only doubles the communication volume between the CPU and GPU but also introduces significant latency due to repetitive data transfers. Furthermore, both parameters and activations are required in the backward pass to complete gradient computations. This means that parameters and activation values must be offloaded during each forward pass and re-uploaded to the GPU for the backward pass, increasing the volume of data transferred, which severely impacts training throughput. On the other hand, zeroth-order (ZO) methods offer a novel approach to fine-tuning LLMs. These methods utilize dual forward passes to estimate parameter gradients and subsequently update parameters. This approach eliminates the traditional reliance on backward passes, thereby streamlining the training process by significantly reducing the number of computational steps required. Based on these observations, we conjecture that ZO's architecture is particularly well-suited for CPU offloading strategies. By eliminating backward passes and the need to store activation values, it can significantly reduce GPU memory demands through efficient parameter offloading. However, despite these advantages, ZO training via CPU offloading introduces new challenges, particularly in the realm of CPU-to-GPU communication. 
Transferring parameters between the CPU and GPU, which is crucial for maintaining gradient computation and model updates, becomes a critical bottleneck. Although ZO methods inherently extend computation times because of the dual forward passes, potentially allowing for better overlap between computation and communication, there remain significant inefficiencies. The necessity to upload parameters to the GPU for upcoming computations introduces a large volume of communications. To tackle the inefficiencies highlighted, we introduce ZO2, a novel framework specifically designed for ZO fine-tuning in LLMs with CPU offloading. This framework utilizes the unique dual forward pass architecture of ZO methods to optimize interactions between CPU and GPU, significantly enhancing both computational and communication efficiency. By building a high-performance dynamic scheduler, ZO2 achieves substantial overlaps in communication and computation. These innovations make it feasible to fine-tune extremely large models, such as the OPT-175B, with over 175 billion parameters, on a single GPU equipped with just 18GB of memory usage—a capability previously unattainable with conventional methods. 
# Tokenize the demo corpus into a single batch on the working device
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
data_batch = tokenizer(dataset, add_special_tokens=True, return_tensors='pt').input_ids.to(args.working_device)
max_positions = model.config.max_position_embeddings
T = min(data_batch.shape[1] - 1, max_positions)
# Guard: never feed more tokens than the model's configured position limit.
# This is a no-op whenever the corpus already fits.
data_batch = data_batch[:, :max_positions]
print(f"Fine-tuning model {args.model_name} with {T} tokens dataset: \n{dataset}")

# Training loop: every step trains on the same batch (and optionally evaluates on it)
for i in tqdm(range(args.max_steps)):
    # train
    model.zo_train()
    loss = model(input_ids=data_batch, labels=data_batch)

    # eval
    if args.eval:
        if i == 0:
            tqdm.write("Warning: please notice that ZO2 does not optimize the evaluation, so it may be very slow.")
        model.zo_eval()
        output = model(input_ids=data_batch, labels=data_batch)
        res = "Iteration {}, train loss: {}, projected grad: {}, eval loss: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad, output["loss"]))
    else:
        res = "Iteration {}, train loss: {}, projected grad: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad))

# inference
print("Doing inference...")
print("Warning: please notice that ZO2 does not optimize the inference, so it may be very slow.")
model.zo_eval()
prompt = "What is ZO2 and how ZO2 enhance the fine-tuning of large language models?"
# Running sequence of token ids: starts as the prompt and grows by one token per step
generated_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(args.working_device)
for _ in tqdm(range(args.max_new_tokens)):
    # No KV cache is used: the full sequence is re-fed on every step
    outputs = model(input_ids=generated_ids, return_dict=True)
    next_token_logits = outputs.logits[:, -1, :]
    if args.temperature == 1.0:
        # Temperature 1.0 is treated as greedy decoding
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
    else:
        # Any other temperature samples from the tempered distribution
        scaled_logits = next_token_logits / args.temperature
        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
    generated_ids = torch.cat([generated_ids, next_token], dim=-1)

# Decode once, after generation. This also keeps `generated_text` defined when
# --max_new_tokens is 0 (it used to be bound only inside the loop).
generated_text = tokenizer.decode(generated_ids[0])
print(f"Question: {prompt}")
# Skip the echoed prompt plus 4 extra leading characters
# (presumably the decoded BOS token "</s>" -- TODO confirm for non-OPT tokenizers)
print(f"Response: {generated_text[len(prompt)+4:]}...")
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    # Order matters: punctuation is stripped before the article regex runs so
    # that word boundaries are evaluated on clean tokens.
    return white_space_fix(remove_articles(remove_punc(lower(s))))


def calculate_metric(predictions, metric_name):
    """
    Aggregate a list of predictions into a single score.

    Each prediction must expose ``correct_candidate`` and
    ``predicted_candidate`` attributes.

    Supported ``metric_name`` values:
      - "accuracy": exact match of the predicted candidate; if the first
        prediction's ``correct_candidate`` is a list, membership in that list
        is used for every prediction.
      - "em": normalized exact match against any gold answer (QA).
      - "f1": best token-level F1 against the gold answers (QA), computed by
        :func:`f1` so the two code paths cannot diverge.

    Returns the mean score over ``predictions``.
    Raises ValueError for an unsupported ``metric_name`` (previously this fell
    through and silently returned None).
    """
    if metric_name == "accuracy":
        if isinstance(predictions[0].correct_candidate, list):
            return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions])
        return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions])

    if metric_name == "em":
        # For question answering
        return np.mean([
            any(normalize_answer(ans) == normalize_answer(pred.predicted_candidate)
                for ans in pred.correct_candidate)
            for pred in predictions
        ])

    if metric_name == "f1":
        # For question answering
        return np.mean([f1(pred.predicted_candidate, pred.correct_candidate) for pred in predictions])

    raise ValueError(f"Unsupported metric_name: {metric_name!r} (expected 'accuracy', 'em' or 'f1')")


def f1(pred, gold):
    """
    Token-level F1 between ``pred`` and the best-matching answer in ``gold``.

    This separate F1 function is used as non-differentiable metric for SQuAD.

    ``gold`` is a non-empty sequence of reference answers. When its first
    entry is the sentinel "CANNOTANSWER" or "no answer", the score is 1 if the
    normalized prediction equals the normalized sentinel and 0 otherwise.
    """
    if gold[0] == "CANNOTANSWER" or gold[0] == "no answer":
        return int(normalize_answer(gold[0]) == normalize_answer(pred))

    # The prediction does not change across gold answers; normalize it once.
    prediction_tokens = normalize_answer(pred).split()
    prediction_counts = Counter(prediction_tokens)

    all_f1s = []
    for ans in gold:
        ground_truth_tokens = normalize_answer(ans).split()
        common = prediction_counts & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            all_f1s.append(0)
        else:
            # num_same > 0 guarantees both token lists are non-empty.
            precision = 1.0 * num_same / len(prediction_tokens)
            recall = 1.0 * num_same / len(ground_truth_tokens)
            all_f1s.append((2 * precision * recall) / (precision + recall))
    return np.max(all_f1s)
SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True CB) # It has <1000 training examples. Only use 100 for dev DEV=100 ;; Copa) # It has <1000 training examples. Only use 100 for dev DEV=100 TASK_ARGS="--train_as_classification False" ;; ReCoRD) TASK_ARGS="--train_as_classification False" ;; DROP) TASK_ARGS="--train_as_classification False" ;; SQuAD) TASK_ARGS="--train_as_classification False" ;; esac echo $TAG echo "BS: $BS" echo "LR: $LR" echo "EPS: $EPS" echo "SEED: $SEED" echo "TRAIN/EVAL STEPS: $STEPS/$EVAL_STEPS" echo "MODE: $MODE" echo "Extra args: $EXTRA_ARGS $TASK_ARGS" python run.py \ --model_name $MODEL \ --task_name $TASK \ --output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \ --max_steps $STEPS \ --trainer zo --load_float16 \ --learning_rate $LR --zo_eps $EPS --per_device_train_batch_size $BS --lr_scheduler_type "constant" \ --load_best_model_at_end --eval_strategy steps --save_strategy steps --save_total_limit 1 \ --eval_steps $EVAL_STEPS --save_steps $EVAL_STEPS \ --train_as_classification \ $EXTRA_ARGS \ $TASK_ARGS \ "$@" ================================================ FILE: example/mezo_runner/run.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 """ Modified from https://github.com/princeton-nlp/MeZO/blob/main/large_models/run.py """ import sys sys.path.append("../../../zo2") import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) import argparse import time import tasks from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, HfArgumentParser, Trainer, TrainingArguments, DataCollatorWithPadding, DataCollatorForTokenClassification from typing import Union, Optional import torch from 
torch.nn.parameter import Parameter import numpy as np from dataclasses import dataclass, is_dataclass, asdict from tqdm import tqdm from tasks import get_task import json import torch.nn.functional as F from torch.utils.data import Dataset from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from metrics import calculate_metric from utils import * import random from zo2.trainer.hf_transformers.trainer import ZOTrainer from zo2 import zo_hf_init, ZOConfig @dataclass class OurArguments(TrainingArguments): # dataset and sampling strategy task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP # Number of examples num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples num_dev: int = None # (only enabled with training) number of development samples num_eval: int = None # number of evaluation samples num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample train_set_seed: int = None # designated seed to sample training samples/demos result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config # Model loading model_name: str = "facebook/opt-125m" # HuggingFace model name load_float16: bool = False # load model parameters as float16 load_bfloat16: bool = False # load model parameters as bfloat16 load_int8: bool = False # load model parameters as int8 max_length: int = 2048 # max length the model can take no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP # Calibration sfc: bool = False # whether to use SFC calibration icl_sfc: bool = False # whether to use SFC calibration for ICL samples # Training trainer: str = "none" ## 
options ## - none: no training -- for zero-shot or in-context learning (ICL) ## - regular: regular huggingface trainer -- for fine-tuning ## - zo: zeroth-order (MeZO) training only_train_option: bool = True # whether to only train the option part of the input train_as_classification: bool = False # take the log likelihood of all options and train as classification # MeZO zo_eps: float = 1e-3 # eps in MeZO # Prefix tuning prefix_tuning: bool = False # whether to use prefix tuning num_prefix: int = 5 # number of prefixes to use no_reparam: bool = True # do not use reparameterization trick prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words # LoRA lora: bool = False # whether to use LoRA lora_alpha: int = 16 # alpha in LoRA lora_r: int = 8 # r in LoRA # Generation sampling: bool = False # whether to use sampling temperature: float = 1.0 # temperature for generation num_beams: int = 1 # number of beams for generation top_k: int = None # top-k for generation top_p: float = 0.95 # top-p for generation max_new_tokens: int = 50 # max number of new tokens to generate eos_token: str = "\n" # end of sentence token # Saving save_model: bool = False # whether to save the model no_eval: bool = False # whether to skip evaluation tag: str = "" # saving tag # Linear probing linear_probing: bool = False # whether to do linear probing lp_early_stopping: bool = False # whether to do early stopping in linear probing head_tuning: bool = False # head tuning: only tune the LM head # Untie emb/lm_head weights untie_emb: bool = False # untie the embeddings and LM head # Display verbose: bool = False # verbose output # Non-diff objective non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now) # Auto saving when interrupted save_on_interrupt: bool = False # save model when interrupted (useful for long training) # ZO2 added -> ZO2 configs zo_method: str = "mezo-sgd" zo_mode: str = "zo2" offloading_device: str = 
class Framework:
    """
    Owns the ZO2-wrapped causal LM and its tokenizer for one task, and exposes
    zero-shot/ICL evaluation (`evaluate`) and MeZO training (`train`).
    """

    def __init__(self, args, task):
        self.args = args
        self.task = task
        self.model, self.tokenizer = self.load_model()


    def load_model(self):
        """
        Load the HuggingFace model inside the ZO2 init context, plus its tokenizer.

        Returns a ``(model, tokenizer)`` tuple. Only model names containing
        "opt" or "Qwen3" are supported; anything else raises NotImplementedError,
        as do the prefix-tuning, LoRA and head-tuning options.
        """
        args = self.args

        # float16 takes precedence over bfloat16 when both flags are set.
        if args.load_float16:
            torch_dtype, precision = torch.float16, "FP16"
        elif args.load_bfloat16:
            torch_dtype, precision = torch.bfloat16, "BF16"
        else:
            torch_dtype, precision = torch.float32, "FP32"

        with count_time("Loading model with %s" % precision):
            free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
            config = AutoConfig.from_pretrained(args.model_name)
            if args.untie_emb:
                # Untie embeddings/LM head
                logger.warning("Untie embeddings and LM head")
                config.tie_word_embeddings = False

            # ZO2 added -> init ZO2 model
            # Set up ZO configuration
            self.zo_config = ZOConfig(
                method=args.zo_method,
                zo2=(args.zo_mode == "zo2"),
                lr=args.learning_rate,
                weight_decay=args.weight_decay,
                eps=args.zo_eps,
                offloading_device=args.offloading_device,
                working_device=args.working_device,
            )

            # Initialize model within zo_hf_init context. The transformers
            # imports deliberately stay inside the context, as in the original.
            with zo_hf_init(self.zo_config):
                if "opt" in args.model_name:
                    from transformers import OPTForCausalLM as model_cls
                elif "Qwen3" in args.model_name:
                    from transformers import Qwen3ForCausalLM as model_cls
                else:
                    # Previously this fell through and crashed with an
                    # UnboundLocalError on `model`.
                    raise NotImplementedError(
                        f"Unsupported model for ZO2: {args.model_name!r} "
                        "(expected a name containing 'opt' or 'Qwen3')"
                    )
                model = model_cls.from_pretrained(
                    args.model_name,
                    config=config,
                    torch_dtype=torch_dtype,
                    max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},
                    load_in_8bit=args.load_int8,
                )
                model.zo_init(self.zo_config)
            logger.info(f"Check if zo2 init correctly: {hasattr(model, 'zo_training')}")

            # Only the on-device mode moves the whole model to the working
            # device. The check must use `zo_mode` ("zo"/"zo2"): `zo_method`
            # holds the optimizer name (e.g. "mezo-sgd") and never equals "zo2",
            # so testing it moved the model even when offloading was requested.
            # NOTE(review): in "zo2" mode placement is presumably handled by
            # zo_init -- confirm against the zo2 package.
            if args.zo_mode != "zo2":
                model = model.to(args.working_device)
            model.eval()

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)

        # HF tokenizer bug fix
        if "opt" in args.model_name:
            tokenizer.bos_token_id = 0

        if "llama" in args.model_name:
            # LLaMA padding token
            tokenizer.pad_token_id = 0

        if "Qwen3" in args.model_name:
            # Do not prepend a BOS token for Qwen3
            tokenizer.add_bos_token = False

        # Prefix tuning/LoRA/head tuning are not ported to ZO2
        if args.prefix_tuning:
            raise NotImplementedError
        if args.lora:
            raise NotImplementedError
        if args.head_tuning:
            raise NotImplementedError

        return model, tokenizer


    def forward(self, input_ids, option_len=None, generation=False):
        """
        Given input_ids and the length of the option, return the log-likelihood of each token in the option.
        For generation tasks, return the generated text.
        This function is only for inference.

        `input_ids` is a single un-batched sequence of token ids; `option_len`
        is required when `generation` is False.
        """
        input_ids = torch.tensor([input_ids]).to(self.model.device)

        if generation:
            args = self.args
            # Autoregressive generation; stop on either the configured EOS
            # string's last token or the tokenizer's own EOS id.
            outputs = self.model.generate(
                input_ids, do_sample=args.sampling, temperature=args.temperature,
                num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k,
                max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)),
                num_return_sequences=1,
                eos_token_id=[self.tokenizer.encode(args.eos_token, add_special_tokens=False)[-1], self.tokenizer.eos_token_id],
            )
            # For generation, directly return the text output (prompt stripped)
            output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip()
            return output_text

        with torch.inference_mode():
            self.model.eval()
            logits = self.model(input_ids=input_ids).logits
        # Shift by one: logits at position t score the token at position t+1.
        labels = input_ids[0, 1:]
        logits = logits[0, :-1]
        log_probs = F.log_softmax(logits, dim=-1)

        selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels]
        selected_log_probs = selected_log_probs.cpu().detach()
        # Only return the option (candidate) part
        return selected_log_probs[-option_len:]


    def one_step_pred(self, train_samples, eval_sample, verbose=False):
        """
        Return the prediction on the eval sample. In ICL, use train_samples as demonstrations.

        For generation tasks the Prediction carries the gold answer and the
        generated text; otherwise it carries candidate indices.
        """
        verbose = verbose or self.args.verbose
        use_sfc = self.args.sfc or self.args.icl_sfc
        if verbose:
            logger.info("========= Example =========")
            logger.info(f"Candidate: {eval_sample.candidates}")
            logger.info(f"Correct candidate: {eval_sample.correct_candidate}")

        # Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options)
        encoded_candidates, option_lens = encode_prompt(
            self.task, self.task.get_template(), train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length,
            generation=self.task.generation, max_new_tokens=self.args.max_new_tokens
        )

        # Calibration
        if use_sfc:
            sfc_encoded_candidates, sfc_option_lens = encode_prompt(
                self.task, self.task.get_template(), train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length,
                sfc=self.args.sfc, icl_sfc=self.args.icl_sfc, generation=self.task.generation,
                max_new_tokens=self.args.max_new_tokens
            )

        if self.task.generation:
            # For generation tasks, return the autoregressively-generated text
            output_text = self.forward(encoded_candidates[0], generation=True)
            if verbose:
                logger.info("=== Prompt ===")
                logger.info(self.tokenizer.decode(encoded_candidates[0]))
                logger.info(f"Output: {output_text}")
            return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text)

        # For classification/multiple-choice, calculate the probabilities of all candidates
        outputs = []
        for candidate_id, encoded_candidate in enumerate(encoded_candidates):
            selected_log_probs = self.forward(encoded_candidate, option_len=option_lens[candidate_id])
            if verbose:
                if candidate_id == 0:
                    logger.info("=== Candidate %d ===" % candidate_id)
                    logger.info(self.tokenizer.decode(encoded_candidate))
                else:
                    logger.info("=== Candidate %d (without context)===" % candidate_id)
                    logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1])
                logger.info(f"Log probabilities of the option tokens: {selected_log_probs}")

            if use_sfc:
                sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id], option_len=sfc_option_lens[candidate_id])
                if verbose:
                    logger.info("=== Candidate %d (without context) SFC ===" % candidate_id)
                    logger.info(self.tokenizer.decode(sfc_encoded_candidates[candidate_id]).split(self.task.train_sep)[-1])
                    logger.info(f"Log probabilities of the option tokens: {sfc_selected_log_probs}")

            outputs.append({"log_probs": selected_log_probs, "sfc_log_probs": sfc_selected_log_probs if use_sfc else None})

        if use_sfc:
            # Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf)
            # log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt)
            scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs]
        else:
            # (Default) length-normalized log probabilities
            # log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens|
            scores = [x['log_probs'].mean().item() for x in outputs]

        if verbose:
            logger.info(f"Prediction scores: {scores}")

        if isinstance(eval_sample.correct_candidate, list):
            # For some datasets there are multiple correct answers
            correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate]
        else:
            correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate)

        return Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores)))


    def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False):
        """
        Evaluate function. If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set.

        Returns a dict mapping the task's metric name (default "accuracy") to its value.
        """
        if one_train_set_per_eval_sample:
            logger.info(f"There are {len(eval_samples)} validation samples and one train set per eval sample")
        else:
            logger.info(f"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples")

        # Prediction loop; the first three samples are logged verbosely
        predictions = []
        for eval_id, eval_sample in enumerate(tqdm(eval_samples)):
            predictions.append(
                self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples, eval_sample, verbose=(eval_id < 3))
            )

        # Calculate metrics
        metric_name = getattr(self.task, "metric_name", "accuracy")
        metrics = {metric_name: calculate_metric(predictions, metric_name)}
        return metrics


    def train(self, train_samples, eval_samples):
        """
        Training function: tokenize both splits, run ZOTrainer (resuming from the
        last checkpoint in output_dir when one exists), then keep the trained model.
        """
        # Local imports: `os` is otherwise only in scope through a star import.
        import os
        from transformers.trainer_utils import get_last_checkpoint

        # Set tokenizer to left padding (so that all the options are right aligned)
        self.tokenizer.padding_side = "left"

        class HFDataset(Dataset):

            def __init__(self, data):
                self.data = data

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                return self.data[idx]


        def _convert(samples):
            """
            Convert samples to HF-compatible dataset
            """
            data = []
            for sample in samples:
                encoded_candidates, option_lens = encode_prompt(
                    self.task, self.task.get_template(), [], sample, self.tokenizer,
                    max_length=self.args.max_length, generation=self.task.generation, generation_with_gold=True,
                    max_new_tokens=self.args.max_new_tokens
                )
                if self.task.generation:
                    correct_candidate_id = 0
                elif isinstance(sample.correct_candidate, list):
                    correct_candidate_id = sample.candidates.index(sample.correct_candidate[0])
                else:
                    correct_candidate_id = sample.candidates.index(sample.correct_candidate)

                if self.args.non_diff:
                    # For non-differentiable objective, there is no teacher forcing thus the
                    # current answer part is removed
                    encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][:-option_lens[correct_candidate_id]]

                if self.args.train_as_classification:
                    # For classification, we provide the label as the correct candidate id
                    data.append([{"input_ids": encoded_candidates[_i], "labels": correct_candidate_id, "option_len": option_lens[_i], "num_options": len(sample.candidates)} for _i in range(len(encoded_candidates))])
                elif self.args.only_train_option:
                    # Otherwise, it is just LM-style teacher forcing
                    if self.args.non_diff:
                        # For non-differentiable objective, we need to provide the gold answer to calculate F1/acc
                        data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id], "gold": sample.correct_candidate})
                    else:
                        data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id], "option_len": option_lens[correct_candidate_id]})
                else:
                    data.append({"input_ids": encoded_candidates[correct_candidate_id], "labels": encoded_candidates[correct_candidate_id]})
            return data

        with count_time("Tokenizing training samples"):
            train_dataset = HFDataset(_convert(train_samples))
            eval_dataset = HFDataset(_convert(eval_samples))

        restrict_loss_to_option = self.args.only_train_option and not self.args.non_diff
        if restrict_loss_to_option:
            # ZO2 added -> register custom loss functions
            self.model.zo_custom_train_loss_fn = custom_loss_fn_with_option_len
            self.model.zo_custom_eval_loss_fn = custom_loss_fn_with_option_len

        # Classification batches are nested lists of options, so they take
        # precedence over the non-diff / token-classification collators.
        if self.args.train_as_classification:
            data_collator = DataCollatorWithPaddingAndNesting(self.tokenizer, pad_to_multiple_of=8)
        elif self.args.non_diff:
            data_collator = NondiffCollator(self.tokenizer, pad_to_multiple_of=8)
        else:
            data_collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8)

        # ZO2 added ->
        trainer = ZOTrainer(
            model=self.model,
            args=self.args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=self.tokenizer,
            data_collator=data_collator,
        )
        if self.args.save_on_interrupt:
            trainer.add_callback(SIGUSR1Callback())

        # Resume training from a last checkpoint
        last_checkpoint = None
        if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir:
            last_checkpoint = get_last_checkpoint(self.args.output_dir)
        if last_checkpoint is not None and self.args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
        # An explicit --resume_from_checkpoint overrides the auto-detected one.
        if self.args.resume_from_checkpoint is not None:
            last_checkpoint = self.args.resume_from_checkpoint

        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Explicitly save the model
        if self.args.save_model:
            logger.warning("Save model..")
            trainer.save_model()

        # FSDP compatibility
        self.model = trainer.model

        if restrict_loss_to_option:
            # ZO2 added -> remove the custom loss functions for evaluation
            self.model.zo_custom_train_loss_fn = None
            self.model.zo_custom_eval_loss_fn = None
def get_task(task_name):
    """
    Instantiate the dataset class registered for `task_name`.

    `task_name` is either "<Group>" or "<Group>__<subtask>"; the class
    "<Group>Dataset" is looked up in this module and constructed with the
    subtask (None unless the name splits into exactly two parts on "__").
    """
    parts = task_name.split("__")
    # Anything other than an exact two-part name keeps only the group.
    subtask = parts[1] if len(parts) == 2 else None
    dataset_cls = getattr(sys.modules[__name__], f"{parts[0]}Dataset")
    return dataset_cls(subtask)
{len(train_samples[-1])}/{len(self.samples['train'])}") logger.info(f"... including dev set {num_dev} samples") return train_samples def sample_subset(self, data_split="train", seed=0, num=100, exclude=None): with temp_seed(seed): samples = self.samples[data_split] lens = len(samples) index = np.random.permutation(lens).tolist()[:num if exclude is None else num+1] if exclude is not None and exclude in index: index.remove(exclude) else: index = index[:num] return [samples[i] for i in index] @property def valid_samples(self): return self.samples["valid"] class SST2Dataset(Dataset): train_sep = "\n\n" def __init__(self, subtask=None, **kwargs) -> None: self.load_dataset(subtask, **kwargs) def load_dataset(self, path, **kwargs): d = load_dataset('glue', 'sst2') train_d = d["train"] validation_d = d["validation"] train_samples = [self.build_sample(example) for example in train_d] valid_samples = [self.build_sample(example) for example in validation_d] self.samples = {"train": train_samples, "valid": valid_samples} # for generative tasks, candidates are [] def build_sample(self, example): label = int(example["label"]) return Sample(id=example["idx"], data=example, correct_candidate=label, candidates=[0, 1]) def get_template(self, template_version=0): return {0: SST2Template}[template_version]() class CopaDataset(Dataset): train_sep = "\n\n" mixed_set = False def __init__(self, subtask=None, **kwargs) -> None: self.load_dataset(subtask, **kwargs) def load_dataset(self, path, **kwargs): train_examples = load_dataset('super_glue', "copa")["train"] valid_examples = load_dataset('super_glue', "copa")["validation"] train_samples = [self.build_sample(example) for example in train_examples] valid_samples = [self.build_sample(example) for example in valid_examples] self.samples = {"train": train_samples, "valid": valid_samples} # for generative tasks, candidates are [] def build_sample(self, example): sample = \ Sample( id=example["idx"], data=example, 
class BoolQDataset(Dataset):
    """BoolQ task: every sample offers the candidates "Yes" and "No"."""

    def __init__(self, subtask=None, **kwargs) -> None:
        # Run the base initializer so ``self.subtask`` exists; without it
        # ``get_task_name()`` raises AttributeError on this class.
        super().__init__(subtask, **kwargs)
        self.load_dataset(subtask, **kwargs)

    def load_dataset(self, path, **kwargs):
        """Build ``self.samples`` from the "boolq" train/validation splits.

        ``path`` and ``kwargs`` are accepted but not used. The bare
        ``load_dataset`` call below resolves to the module-level function
        (presumably ``datasets.load_dataset`` -- the import is outside this
        block), not to this method.
        """
        d = load_dataset("boolq")
        train_set = d["train"]
        valid_set = d["validation"]

        train_samples = [self.build_sample(example) for example in train_set]
        valid_samples = [self.build_sample(example) for example in valid_set]
        self.samples = {"train": train_samples, "valid": valid_samples}

    def build_sample(self, example):
        """Wrap one raw example; a truthy ``example["answer"]`` maps to "Yes"."""
        return Sample(
            data=example,
            candidates=["Yes", "No"],
            correct_candidate="Yes" if example["answer"] else "No",
        )

    def get_template(self, template_version=2):
        """Return a template instance; version 2 (``BoolQTemplateV3``) is the default."""
        return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]()
class WICDataset(Dataset):
    """SuperGLUE WiC task with the integer candidates 0 and 1."""

    def __init__(self, subtask=None, **kwargs) -> None:
        # Run the base initializer so ``self.subtask`` exists; without it
        # ``get_task_name()`` raises AttributeError on this class.
        super().__init__(subtask, **kwargs)
        self.load_dataset(subtask, **kwargs)

    def load_dataset(self, path, **kwargs):
        """Build ``self.samples`` from the "super_glue"/"wic" train and validation splits.

        ``path`` and ``kwargs`` are accepted but not used. The bare
        ``load_dataset`` call resolves to the module-level function
        (presumably ``datasets.load_dataset`` -- the import is outside this
        block), not to this method.
        """
        d = load_dataset("super_glue", "wic")
        train_set = d["train"]
        valid_set = d["validation"]

        train_samples = [self.build_sample(example) for example in train_set]
        valid_samples = [self.build_sample(example) for example in valid_set]
        self.samples = {"train": train_samples, "valid": valid_samples}

    def build_sample(self, example):
        """Wrap one raw example; the gold candidate is ``example['label']``."""
        return Sample(
            data=example,
            candidates=[0, 1],
            correct_candidate=example['label'],
        )

    def get_template(self, template_version=0):
        """Return a ``WICTemplate`` instance (only version 0 is registered)."""
        return {0: WICTemplate}[template_version]()
class SQuADDataset(Dataset):
    """SQuAD as a generation task (``generation = True``, metric name "f1")."""

    metric_name = "f1"
    generation = True

    def __init__(self, subtask=None, **kwargs) -> None:
        # Run the base initializer so ``self.subtask`` exists; without it
        # ``get_task_name()`` raises AttributeError on this class.
        super().__init__(subtask, **kwargs)
        self.load_dataset()

    def load_dataset(self):
        """Build ``self.samples`` from the "squad" train/validation splits.

        The bare ``load_dataset`` call resolves to the module-level function
        (presumably ``datasets.load_dataset`` -- the import is outside this
        block), not to this method. Sample ids are the positional index of each
        example within its split.
        """
        dataset = load_dataset("squad")
        train_examples = dataset["train"]
        valid_examples = dataset["validation"]

        train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]
        valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]
        self.samples = {"train": train_samples, "valid": valid_samples}

    def build_sample(self, example, idx):
        """Wrap one raw example.

        For this generative task ``candidates`` is ``None`` and the gold value
        is the full list of answer texts. Raises ``AssertionError`` when the
        example carries no answer.
        """
        answers = example['answers']['text']
        assert len(answers) > 0
        return Sample(
            id=idx,
            data={
                "title": example['title'],
                "context": example['context'],
                "question": example['question'],
                "answers": answers,
            },
            candidates=None,
            correct_candidate=answers,
        )

    def get_template(self, template_version=0):
        """Return a ``SQuADv2Template`` instance (only version 0 is registered)."""
        return {0: SQuADv2Template}[template_version]()
class Template:
    """Base prompt template; task templates override the hooks below."""

    def encode(self, sample):
        """
        Build the prompt for ``sample`` with the answer/candidate left out.
        The base class defines no prompt format and always raises.
        """
        raise NotImplementedError()

    def verbalize(self, sample, candidate):
        """
        Build the prompt for ``sample`` with the answer/candidate filled in.
        The base implementation hands ``candidate`` back untouched.
        """
        answer = candidate
        return answer

    def encode_sfc(self, sample):
        """
        Calibration (SFC) counterpart of ``encode`` -- usually the input is
        left out; the base implementation yields an empty string.
        """
        empty_prompt = ""
        return empty_prompt

    def verbalize_sfc(self, sample, candidate):
        """
        Calibration (SFC) counterpart of ``verbalize`` -- usually the input is
        left out; the base implementation hands ``candidate`` back untouched.
        """
        answer = candidate
        return answer
class CopaTemplate(Template):
    """COPA prompt of the form ``<premise><conjunction><choice>``."""

    capitalization: str = "correct"
    effect_conj: str = " so "
    cause_conj: str = " because "

    def get_conjucture(self, sample):
        """Pick the connective that matches the sample's question type."""
        question_type = sample.data["question"]
        if question_type == "effect":
            return self.effect_conj
        if question_type == "cause":
            return self.cause_conj
        raise NotImplementedError

    def get_prompt(self, sample):
        """Premise (one trailing period removed) followed by the connective."""
        stem = sample.data["premise"].rstrip()
        # TODO Add other scripts with different punctuation
        if stem.endswith("."):
            stem = stem[:-1]
        text = stem + self.get_conjucture(sample)

        mode = self.capitalization
        if mode == "upper":
            return text.upper()
        if mode == "lower":
            return text.lower()
        return text

    def encode(self, sample):
        return self.get_prompt(sample)

    def capitalize(self, c):
        """Apply the configured capitalization mode to a candidate string."""
        mode = self.capitalization
        if mode == "bug":
            return c
        if mode == "upper":
            return c.upper()
        if mode == "lower":
            return c.lower()
        if mode != "correct":
            raise NotImplementedError

        # "correct": lower-case the first word unless it is the pronoun "I"
        first, *rest = c.split(" ")
        if first != "I":
            first = first.lower()
        return " ".join([first, *rest])

    def verbalize(self, sample, candidate):
        return self.get_prompt(sample) + self.capitalize(candidate)

    def encode_sfc(self, sample):
        return self.get_conjucture(sample).strip()

    def verbalize_sfc(self, sample, candidate):
        bare_conjunction = self.get_conjucture(sample).strip()
        return bare_conjunction + " " + self.capitalize(candidate)
class BoolQTemplateV3(Template):
    """BoolQ prompt: passage, then the question, then a newline before the answer."""

    def _passage_and_question(self, sample):
        # Shared by encode/verbalize: make sure the question ends with "?" and
        # starts with an upper-case letter, then prefix it with the passage.
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question += "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}"

    def encode(self, sample):
        return self._passage_and_question(sample) + "\n"

    def verbalize(self, sample, candidate):
        return f"{self._passage_and_question(sample)}\n{candidate}"

    def encode_sfc(self, sample):
        # Calibration prompt carries no input at all.
        return ""

    def verbalize_sfc(self, sample, candidate):
        # Calibration prompt is just the candidate.
        return candidate
class CBTemplate(Template):
    """CB prompt with a three-way Yes/No/Maybe verbalizer."""

    # From PromptSource 1
    verbalizer = {0: "Yes", 1: "No", 2: "Maybe"}

    def _question(self, sample):
        # Prompt text shared by encode and verbalize (ends with a newline).
        premise = sample.data["premise"]
        hypothesis = sample.data["hypothesis"]
        return f'Suppose {premise} Can we infer that "{hypothesis}"? Yes, No, or Maybe?\n'

    def encode(self, sample):
        return self._question(sample)

    def verbalize(self, sample, candidate):
        question = self._question(sample)
        return question + f"{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        # Calibration prompt carries no input at all.
        return ""

    def verbalize_sfc(self, sample, candidate):
        # Calibration prompt is just the verbalized label.
        label_word = self.verbalizer[candidate]
        return f"{label_word}"
class ReCoRDTemplateGPT3(Template):
    """ReCoRD prompt that renders the passage highlights and the query as "- " bullets."""

    # From PromptSource 1 but modified

    @staticmethod
    def _bulleted_passage(sample):
        # Each "@highlight\n" marker in the passage becomes a "- " bullet.
        return sample.data['passage'].replace("@highlight\n", "- ")

    @staticmethod
    def _filled_query(sample, candidate):
        # Substitute the candidate for "@placeholder"; a list candidate
        # contributes its first element only.
        query = sample.data['query']
        filler = candidate[0] if isinstance(candidate, list) else candidate
        return query.replace("@placeholder", filler)

    def encode(self, sample):
        return f"{self._bulleted_passage(sample)}\n-"

    def verbalize(self, sample, candidate):
        passage = self._bulleted_passage(sample)
        query = self._filled_query(sample, candidate)
        return f"{passage}\n- {query}"

    def encode_sfc(self, sample):
        return "-"

    def verbalize_sfc(self, sample, candidate):
        return f"- {self._filled_query(sample, candidate)}"
def custom_loss_fn_with_option_len(self, input_ids, logits, labels, option_len=None, num_options=None):
    """
    Loss over the option part of each sequence (modified from
    'forward_wrap_with_option_len' below).

    Input:
    - self: object exposing `config.pad_token_id` and `config.vocab_size`
    - input_ids: token ids; the shifted ids are used as prediction targets
    - logits: model outputs, indexed as (..., position, vocab)
    - labels: only read in classification mode, where they give the index of the
      correct option (repeated for every option of an example)
    - option_len: per-row number of trailing tokens the loss is computed on.
      `None` (the default) applies no prefix masking; previously `None` crashed
      with a TypeError.
    - num_options: per-row number of options. When given, a classification loss
      over the options' mean token log-probabilities is returned; otherwise a
      token-level cross entropy.

    Output: a scalar loss tensor.
    """
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs
    shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
    shift_labels[shift_labels == self.config.pad_token_id] = -100

    # Apply option len (do not calculate loss on the non-option part)
    if option_len is not None:
        for _i, _len in enumerate(option_len):
            # NOTE(review): for _len == 0 the slice `:-0` is empty, so nothing is
            # masked for that row -- confirm whether callers rely on this.
            shift_labels[_i, :-_len] = -100

    # Calculate the loss
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    if num_options is None:
        return loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

    # Train as a classification tasks
    log_probs = F.log_softmax(shift_logits, dim=-1)
    mask = shift_labels != -100  # Option part
    shift_labels[~mask] = 0  # So that it doesn't mess up with indexing

    selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)  # (bsz x num_options, len)
    selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1)  # (bsz x num_options)

    if any(x != num_options[0] for x in num_options):
        # Multi choice tasks with different number of options: walk the rows
        # group by group and average the per-example losses.
        loss = 0
        start_id = 0
        count = 0
        while start_id < len(num_options):
            end_id = start_id + num_options[start_id]
            _logits = selected_log_probs[start_id:end_id].unsqueeze(0)  # (1, num_options)
            _labels = labels[start_id:end_id][0].unsqueeze(0)  # (1)
            loss = loss_fct(_logits, _labels) + loss
            count += 1
            start_id = end_id
        return loss / count

    num_options = num_options[0]
    selected_log_probs = selected_log_probs.view(-1, num_options)  # (bsz, num_options)
    labels = labels.view(-1, num_options)[:, 0]  # Labels repeat so we only take the first one
    return loss_fct(selected_log_probs, labels)
sequence: loss will only be calculated on part of the sequence (2) Classification-style training: a classification loss (CE) will be calculated over several options Input: - input_ids, labels: same as the original forward function - option_len: a list of int indicating the option lengths, and loss will be calculated only on the last option_len tokens - num_options: a list of int indicating the number of options for each example (this will be #label words for classification tasks and #choices for multiple choice tasks), and a classification loss will be calculated. """ outputs = self.original_forward(input_ids=input_ids, **kwargs) if labels is None: return outputs logits = outputs.logits loss = None # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs shift_labels = torch.clone(input_ids)[..., 1:].contiguous() shift_labels[shift_labels == self.config.pad_token_id] = -100 # Apply option len (do not calculate loss on the non-option part) for _i, _len in enumerate(option_len): shift_labels[_i, :-_len] = -100 # Calculate the loss loss_fct = CrossEntropyLoss(ignore_index=-100) if num_options is not None: # Train as a classification tasks log_probs = F.log_softmax(shift_logits, dim=-1) mask = shift_labels != -100 # Option part shift_labels[~mask] = 0 # So that it doesn't mess up with indexing selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len) selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options) if any([x != num_options[0] for x in num_options]): # Multi choice tasks with different number of options loss = 0 start_id = 0 count = 0 while start_id < len(num_options): end_id = start_id + num_options[start_id] _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options) _labels = 
def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, max_new_tokens=None):
    """
    Encode prompts for eval_sample
    Input:
    - task, template: task and template class
    - train_samples, eval_sample: demonstrations and the actual sample
    - tokenizer, max_length: tokenizer and max length
    - sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315)
    - icl_sfc: generate prompts for ICL version calibration
    - generation: whether it is an generation task
    - generation_with_gold: whether to include the generation-task gold answers (for training)
    - max_new_tokens: max number of new tokens to generate so that we can save enough space
      (only for generation tasks)
    Output:
    - encodings: a list of N lists of tokens. N is the number of options for classification/multiple-choice.
    - option_lens: a list of N integers indicating the number of option tokens.
    """
    # Demonstrations for ICL
    train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]
    train_prompts = task.train_sep.join(train_prompts).strip()

    # sfc or icl_sfc indicates that this example is used for calibration
    if sfc or icl_sfc:
        encode_fn = template.encode_sfc
        verbalize_fn = template.verbalize_sfc
    else:
        encode_fn = template.encode
        verbalize_fn = template.verbalize

    unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')
    if not generation:
        # We generate one prompt for each candidate (different classes in classification)
        # or different choices in multiple-choice tasks
        verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]
        unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
        # Option length = extra tokens the candidate adds on top of the bare prompt
        option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]

        if sfc:
            # Without demonstrations
            final_prompts = verbalized_eval_prompts
        else:
            # With demonstrations
            final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
    else:
        assert not sfc and not icl_sfc, "Generation tasks do not support SFC"
        if generation_with_gold:
            verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]
            unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
            option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
            final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
        else:
            option_lens = [0]
            final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]

    # Tokenize
    encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]

    # Truncate (left truncate as demonstrations are less important)
    if generation and max_new_tokens is not None:
        max_length = max_length - max_new_tokens

    if any(len(encoding) > max_length for encoding in encodings):
        logger.warning("Exceed max length")
    # A tokenizer object without an `add_bos_token` attribute is treated as one
    # that adds no BOS token, instead of failing with AttributeError.
    if getattr(tokenizer, "add_bos_token", False):
        # Keep the leading BOS token and left-truncate the remainder
        encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings]
    else:
        encodings = [encoding[-max_length:] for encoding in encodings]

    return encodings, option_lens
class SIGUSR1Callback(transformers.TrainerCallback):
    """
    This callback is used to save the model when a SIGUSR1 signal is received
    (SLURM stop signal or a keyboard interruption signal).

    Creating an instance installs `handle_signal` as the process-wide handler for
    SIGINT and, where the platform defines it, SIGUSR1. Once a signal has been
    seen, the next step end requests a save and a training stop, and the end of
    training exits the process with status 0.
    """

    def __init__(self) -> None:
        super().__init__()
        self.signal_received = False
        # The signal module only defines SIGUSR1 on some platforms; register it
        # when present instead of failing with AttributeError.
        if hasattr(signal, "SIGUSR1"):
            signal.signal(signal.SIGUSR1, self.handle_signal)
        signal.signal(signal.SIGINT, self.handle_signal)
        logger.warning("Handler registered")

    def handle_signal(self, signum, frame):
        """Signal handler: only records that a signal arrived."""
        self.signal_received = True
        logger.warning("Signal received")

    def on_step_end(self, args, state, control, **kwargs):
        """Ask the trainer to save and stop once a signal has been received."""
        if self.signal_received:
            control.should_save = True
            control.should_training_stop = True

    def on_train_end(self, args, state, control, **kwargs):
        """Exit with status 0 after training if a signal triggered the stop."""
        if self.signal_received:
            # `exit()` is a helper injected by the `site` module and is not
            # guaranteed to exist; raising SystemExit is the portable equivalent.
            raise SystemExit(0)
def add_license_header(file_path, comment_style):
    """Prepend a copyright/license header to ``file_path`` unless it is already licensed.

    A file counts as licensed when it already contains the Apache 2.0 license
    sentence anywhere in its text. Failures are logged instead of raised so
    that one bad file does not abort a whole directory walk.

    Args:
        file_path: Path of the file to modify in place.
        comment_style: ``"block"`` for a ``/* ... */`` header, ``"line"`` for a
            ``#`` header. Any other value falls back to a generic
            ``# Copyright Notice`` line.
    """
    try:
        with open(file_path, 'r+', encoding='utf-8') as file:
            content = file.read()
            license_snippet = "Licensed under the Apache License, Version 2.0"
            if license_snippet not in content:
                header = "# Copyright Notice\n"
                if comment_style == "block":
                    header = f"/* Copyright (c) {current_year} {owner}\n * Licensed under the Apache License, Version 2.0\n */\n\n"
                elif comment_style == "line":
                    header = f"# Copyright (c) {current_year} {owner}\n# Licensed under the Apache License, Version 2.0\n\n"
                # Rewind and overwrite from the start; the new text is always
                # longer than the old one, so no truncation is needed.
                file.seek(0, 0)
                file.write(header + content)
    except FileNotFoundError:
        logging.error("File not found: %s", file_path)
    except (OSError, UnicodeDecodeError) as err:
        # Unreadable, unwritable, or non-UTF-8 files are skipped and logged
        # rather than crashing the caller.
        logging.error("Could not update %s: %s", file_path, err)
from setuptools import setup, find_packages

# Read the pinned dependencies, keeping only real requirement specifiers
# (blank lines and comment lines are dropped).
with open('requirements.txt', encoding='utf-8') as f:
    requirements = [
        line.strip()
        for line in f
        if line.strip() and not line.lstrip().startswith('#')
    ]

# Read the README with an explicit encoding and close the handle right away.
with open('README.md', encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='zo2',
    version='0.1.1',
    author='liangyuwang',
    author_email='liangyu.wang@kaust.edu.sa',
    description='ZO2 (Zeroth-Order Offloading), a framework for full parameter fine-tuning 175B LLMs with 18GB GPU memory',
    long_description=long_description,
    long_description_content_type='text/markdown',
    packages=find_packages(),
    install_requires=requirements,  # List of dependencies, read from requirements.txt
    classifiers=[
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.11',
    include_package_data=True,
    zip_safe=False
)
#!/bin/bash

# Run the ZO2 memory test for every OPT model size and report the peak GPU
# and CPU memory found in each run's output.

set -e
set -o pipefail

model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"

        # CMD2 is expanded unquoted on purpose so it word-splits into a command line.
        CMD2="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"

        $CMD2 2>&1 | tee "$OUT2"

        echo "Recording Peak GPU and CPU Memory usage..."
        # With `set -e` and `pipefail`, a grep that matches nothing would abort
        # the script inside the assignment and the diagnostic below would never
        # be printed; `|| true` lets the empty result reach the -z check.
        max_mem1=$(grep 'Peak GPU Memory' "$OUT2" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) || true
        max_mem2=$(grep 'Peak CPU Memory' "$OUT2" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) || true

        if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
            echo "Could not find memory usage data in the output."
        else
            echo -e "Model: $model_name, Task: $task_id"
            echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
            echo -e "ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}"
        fi

        rm "$OUT2"
    done
done
def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
    """Run the ZO causal-LM training loop and log loss and projected grad per step.

    Reads the module-level ``args`` and ``original_dtype`` that the
    ``__main__`` section assigns.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: Device the data and the model are placed on.
    """
    batch = prepare_data_for_causalLM(
        model_config.vocab_size,
        args.batch_size,
        model_config.max_position_embeddings,
        device,
    )
    input_ids, labels = batch

    # Build the model under the requested default dtype, then restore the
    # dtype that was active at start-up.
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForCausalLM(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    line_format = "Iteration {}, loss: {}, projected grad: {}"
    for step in tqdm(range(args.max_steps)):
        # Same mode sequence as before: eval first, then train, then forward.
        model.zo_eval()
        model.zo_train()
        loss = model(input_ids=input_ids, labels=labels)
        tqdm.write(line_format.format(step, loss, model.opt.projected_grad))
def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):
    """Run the ZO2 causal-LM evaluation loop and log the loss of each step.

    Reads the module-level ``args`` and ``original_dtype``. The model is not
    moved to ``device`` here; only the input data is created on it.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: Device the input data is created on.
    """
    input_ids, labels = prepare_data_for_causalLM(
        model_config.vocab_size,
        args.batch_size,
        model_config.max_position_embeddings,
        device,
    )

    # Model construction happens under the requested default dtype.
    torch.set_default_dtype(args.model_dtype)
    model = zo2.OPTForCausalLM(model_config)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    report = "Iteration {}, loss: {}"
    for step in tqdm(range(args.max_steps)):
        # Same mode sequence as before: train first, then eval, then forward.
        model.zo_train()
        model.zo_eval()
        outputs = model(input_ids=input_ids, labels=labels)
        tqdm.write(report.format(step, outputs["loss"]))
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
    """Run the ZO sequence-classification evaluation loop and log each loss.

    Reads the module-level ``args`` and ``original_dtype``.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: Device the data and the model are placed on.
    """
    data = prepare_data_for_sequence_classification(
        model_config.vocab_size,
        args.batch_size,
        model_config.max_position_embeddings,
        device,
    )
    input_ids, labels = data

    # Create the model with the requested default dtype and restore afterwards.
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForSequenceClassification(model_config)
    model = model.to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    message = "Iteration {}, loss: {}"
    for iteration in tqdm(range(args.max_steps)):
        # Same mode sequence as before: train first, then eval, then forward.
        model.zo_train()
        model.zo_eval()
        loss = model(input_ids=input_ids, labels=labels)["loss"]
        tqdm.write(message.format(iteration, loss))
def train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
    """Run the ZO2 question-answering training loop and log loss and projected grad.

    Reads the module-level ``args`` and ``original_dtype``.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: Device the data and the model are placed on.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honor the ``device`` argument (previously hard-coded to "cuda") so the
    # model lands on the same device as the inputs created above.
    model = zo2.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model.zo_train()
        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
        res = "Iteration {}, loss: {}, projected grad: {}"
        tqdm.write(res.format(i, loss, model.opt.projected_grad))
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
    """Run the ZO2 question-answering evaluation loop and log the loss of each step.

    Reads the module-level ``args`` and ``original_dtype``.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: Device the data and the model are placed on.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honor the ``device`` argument (previously hard-coded to "cuda") so the
    # model lands on the same device as the inputs created above.
    model = zo2.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model.zo_eval()
        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)["loss"]
        res = "Iteration {}, loss: {}"
        tqdm.write(res.format(i, loss))
def test_mezo2_sgd_causalLM_eval():
    """Seed the run, build the OPT and ZO2 configs from ``args``, and run the ZO2 causal-LM eval loop."""
    seed_everything(args.seed)

    # Look up the model preset by name and untie the word embeddings.
    model_config = getattr(OPTConfigs(), args.model_name)
    model_config.tie_word_embeddings = False

    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = True
    # Overlap is enabled only for the "all" setting.
    zo_cfg.overlap = (args.overlap == "all")

    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)
def test_mezo_sgd_question_answering_training():
    """Seed the run, build the OPT and ZO configs from ``args``, and run the ZO question-answering training loop."""
    seed_everything(args.seed)

    # Look up the model preset by name and untie the word embeddings.
    presets = OPTConfigs()
    model_config = getattr(presets, args.model_name)
    model_config.tie_word_embeddings = False

    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = False

    train_mezo_sgd_question_answering(
        model_config,
        zo_cfg,
        device=args.working_device,
    )
if __name__=="__main__":
    # ``args`` and ``original_dtype`` must stay module-level names: the
    # train_*/eval_*/test_* functions in this file read them as globals.
    args = get_args()
    original_dtype = torch.get_default_dtype()

    # (zo_method, task) -> (training entry point, evaluation entry point)
    runners = {
        ("zo", "causalLM"): (
            test_mezo_sgd_causalLM_training,
            test_mezo_sgd_causalLM_eval,
        ),
        ("zo", "sequence_classification"): (
            test_mezo_sgd_sequence_classification_training,
            test_mezo_sgd_sequence_classification_eval,
        ),
        ("zo", "question_answering"): (
            test_mezo_sgd_question_answering_training,
            test_mezo_sgd_question_answering_eval,
        ),
        ("zo2", "causalLM"): (
            test_mezo2_sgd_causalLM_training,
            test_mezo2_sgd_causalLM_eval,
        ),
        ("zo2", "sequence_classification"): (
            test_mezo2_sgd_sequence_classification_training,
            test_mezo2_sgd_sequence_classification_eval,
        ),
        ("zo2", "question_answering"): (
            test_mezo2_sgd_question_answering_training,
            test_mezo2_sgd_question_answering_eval,
        ),
    }

    # The method is validated before the task, matching the original order of checks.
    if args.zo_method not in ("zo", "zo2"):
        raise NotImplementedError(f"ZO method {args.zo_method} is unsupported.")
    entry_points = runners.get((args.zo_method, args.task))
    if entry_points is None:
        raise NotImplementedError(f"Task {args.task} is unsupported.")

    run_training, run_eval = entry_points
    if args.eval:
        run_eval()
    else:
        run_training()
"${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" if [ "$task_id" == "causalLM" ]; then lr=1e-4 else lr=1e-7 fi CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr --eval --max_steps 30" CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30" OUT1="/tmp/output1_${model_name}_${task_id}.txt" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Comparing outputs..." echo -e "Model: $model_name, Task: $task_id" paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{ split($4, loss1, ","); split($8, loss2, ","); diff_loss = loss1[1] - loss2[1]; if (loss1[1] == loss2[1]) printf "Iteration %s: %s✓ loss match.%s\n", $2, green, nc; else printf "Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) Loss diff: %.6f%s\n", $2, red, loss1[1], loss2[1], diff_loss, nc; }' rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_opt/test_acc_train.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b") task_ids=("causalLM" "sequence_classification" "question_answering") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" if [ "$task_id" == "causalLM" ]; then lr=1e-4 else lr=1e-7 fi CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr" CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr" OUT1="/tmp/output1_${model_name}_${task_id}.txt" 
OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Comparing outputs..." echo -e "Model: $model_name, Task: $task_id" paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{ split($4, loss1, ","); split($7, proj1, ","); split($11, loss2, ","); split($14, proj2, ","); diff_loss = loss1[1] - loss2[1]; diff_proj = proj1[1] - proj2[1]; if (loss1[1] == loss2[1] && proj1[1] == proj2[1]) printf "Iteration %d: %s✓ loss and projected grad match.%s\n", $2, green, nc; else printf "Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\n \tLoss diff: %.6f, Proj grad diff: %.6f%s\n", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc; }' rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_opt/test_memory.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import sys sys.path.append("../zo2") import torch from tqdm import tqdm from zo2.config.mezo_sgd import MeZOSGDConfig from zo2.model.huggingface.opt.mezo_sgd import zo, zo2 from zo2.utils.utils import seed_everything from utils import ( OPTConfigs, prepare_data_for_causalLM, prepare_data_for_sequence_classification, prepare_data_for_question_answering, model_size, get_args, check_peak_gpu_memory_usage, reset_peak_cpu_memory_usage, check_and_update_peak_cpu_memory_usage ) def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.OPTForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) 
def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):
    """Run the ZO2 causal-LM training loop and report peak GPU/CPU memory per step.

    Reads the module-level ``args`` and ``original_dtype``. The model is not
    moved to ``device`` here; only the input data is created on it.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: CUDA device spec such as ``'cuda'`` or ``'cuda:1'``.
    """
    input_ids, labels = prepare_data_for_causalLM(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    model = zo2.OPTForCausalLM(model_config)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    # Parse the GPU index properly: ``int(device[-1])`` fails for "cuda" and
    # returns the wrong index for two-digit ordinals such as "cuda:10".
    gpu_index = torch.device(device).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    # Reset the statistics of the device that is actually measured below.
    torch.cuda.reset_peak_memory_stats(gpu_index)
    reset_peak_cpu_memory_usage()
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model(input_ids=input_ids, labels=labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
    """Run the ZO sequence-classification training loop and report peak GPU/CPU memory per step.

    Reads the module-level ``args`` and ``original_dtype``.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: CUDA device spec such as ``'cuda'`` or ``'cuda:1'``.
    """
    input_ids, labels = prepare_data_for_sequence_classification(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForSequenceClassification(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    # Parse the GPU index properly: ``int(device[-1])`` fails for "cuda" and
    # returns the wrong index for two-digit ordinals such as "cuda:10".
    gpu_index = torch.device(device).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    # Reset the statistics of the device that is actually measured below.
    torch.cuda.reset_peak_memory_stats(gpu_index)
    reset_peak_cpu_memory_usage()
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model(input_ids=input_ids, labels=labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):
    """Run the ZO sequence-classification evaluation loop and report peak GPU/CPU memory per step.

    Reads the module-level ``args`` and ``original_dtype``.

    Args:
        model_config: OPT model configuration used to build the model and data.
        zo_config: Configuration object passed to ``model.zo_init``.
        device: CUDA device spec such as ``'cuda'`` or ``'cuda:1'``.
    """
    input_ids, labels = prepare_data_for_sequence_classification(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForSequenceClassification(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    # Parse the GPU index properly: ``int(device[-1])`` fails for "cuda" and
    # returns the wrong index for two-digit ordinals such as "cuda:10".
    gpu_index = torch.device(device).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    # Reset the statistics of the device that is actually measured below.
    torch.cuda.reset_peak_memory_stats(gpu_index)
    reset_peak_cpu_memory_usage()
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model(input_ids=input_ids, labels=labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):
    """Run the ZO2 question-answering model for ``args.max_steps`` steps and report peak memory.

    Reads the module-level ``args`` and ``original_dtype`` that are assigned in
    the ``__main__`` section of this script.

    Args:
        model_config: OPT model config; ``vocab_size`` and
            ``max_position_embeddings`` are used to build the input batch.
        zo_config: configuration object passed to ``model.zo_init``.
        device: torch device string for the data and the model, e.g. ``'cuda:0'``.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honour the `device` argument (was hard-coded to "cuda") so the model sits
    # on the same device as the inputs created above.
    model = zo2.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    # Resolve the GPU index through torch.device instead of int(device[-1]),
    # which fails on 'cuda' and truncates multi-digit ids such as 'cuda:10'.
    gpu_index = torch.device(device).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):
    """Run the ZO2 question-answering model in ``zo_eval`` mode and report peak memory.

    Reads the module-level ``args`` and ``original_dtype`` that are assigned in
    the ``__main__`` section of this script.

    Args:
        model_config: OPT model config; ``vocab_size`` and
            ``max_position_embeddings`` are used to build the input batch.
        zo_config: configuration object passed to ``model.zo_init``.
        device: torch device string for the data and the model, e.g. ``'cuda:0'``.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honour the `device` argument (was hard-coded to "cuda") so the model sits
    # on the same device as the inputs created above.
    model = zo2.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    # Resolve the GPU index through torch.device instead of int(device[-1]),
    # which fails on 'cuda' and truncates multi-digit ids such as 'cuda:10'.
    gpu_index = torch.device(device).index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def test_mezo2_sgd_causalLM_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO2 causal-LM eval memory test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = True

    eval_mezo2_sgd_causalLM(opt_cfg, mezo_cfg, device=args.working_device)
def test_mezo2_sgd_sequence_classification_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO2 sequence-classification eval memory test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = True

    eval_mezo2_sgd_sequence_classification(opt_cfg, mezo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO2 question-answering eval memory test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = True

    eval_mezo2_sgd_question_answering(opt_cfg, mezo_cfg, device=args.working_device)
#!/bin/bash
# Runs test_memory.py in --eval mode with --zo_method zo and zo2 for every
# model/task pair, then prints the largest "Peak GPU Memory" value of each run
# and the ZO2/ZO ratio.

set -e
set -o pipefail

model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"

        CMD1="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval"
        CMD2="python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval"

        OUT1="/tmp/output1_${model_name}_${task_id}.txt"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"

        $CMD1 2>&1 | tee "$OUT1"
        $CMD2 2>&1 | tee "$OUT2"

        echo "Analyzing Peak GPU Memory usage..."
        # `|| true`: grep exits 1 when nothing matches; with `set -e` and
        # `pipefail` that would abort the script here and the empty-value
        # check below could never run.
        max_mem1=$(grep 'Peak GPU Memory' "$OUT1" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)
        max_mem2=$(grep 'Peak GPU Memory' "$OUT2" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)

        if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
            echo "Could not find memory usage data in the output."
        else
            ratio=$(echo "scale=2; $max_mem2 / $max_mem1 * 100" | bc)
            echo -e "Model: $model_name, Task: $task_id"
            echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
            echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}"
            echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}"
        fi

        rm -f "$OUT1" "$OUT2"
    done
done
#!/bin/bash
# Runs test_acc.py in --eval mode with --zo_method zo2 twice per model/task
# pair, once with "--overlap no" and once without that flag, then compares the
# per-iteration loss printed by the two runs.

set -e
set -o pipefail

model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")
task_ids=("causalLM" "sequence_classification" "question_answering")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"

        if [ "$task_id" == "causalLM" ]; then
            lr=1e-4
        else
            lr=1e-7
        fi

        CMD1="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --overlap no --max_steps 30"
        CMD2="python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30"

        OUT1="/tmp/output1_${model_name}_${task_id}.txt"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"

        $CMD1 2>&1 | tee "$OUT1"
        $CMD2 2>&1 | tee "$OUT2"

        echo "Comparing outputs..."
        echo -e "Model: $model_name, Task: $task_id"
        # Both runs use zo2; they differ only in "--overlap no", so the
        # mismatch message labels them Non-Overlap / Overlap (it previously
        # said ZO / ZO2).
        paste <(grep 'Iteration' "$OUT1") <(grep 'Iteration' "$OUT2") | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
            split($4, loss1, ",");
            split($8, loss2, ",");
            diff_loss = loss1[1] - loss2[1];
            if (loss1[1] == loss2[1])
                printf "Iteration %s: %s✓ loss match.%s\n", $2, green, nc;
            else
                printf "Iteration %s: %s✗ Mismatch! Non-Overlap (loss): (%s), Overlap (loss): (%s) Loss diff: %.6f%s\n", $2, red, loss1[1], loss2[1], diff_loss, nc;
        }'

        rm -f "$OUT1" "$OUT2"
    done
done
def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
    """Measure throughput of the ZO causal-LM model over ``args.max_steps`` steps.

    Calls ``model.zo_train()`` before every ``check_throughput`` call. Reads the
    module-level ``args`` and ``original_dtype`` assigned in ``__main__``.
    """
    seq_len = model_config.max_position_embeddings
    input_ids, labels = prepare_data_for_causalLM(
        model_config.vocab_size, args.batch_size, seq_len, device)

    # Construct the model under the requested dtype, then restore the default.
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForCausalLM(model_config)
    model = model.to(device)
    model.zo_init(zo_config)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    for step in tqdm(range(args.max_steps)):
        model.zo_train()
        check_throughput(
            step,
            args.batch_size*model_config.max_position_embeddings,
            model,
            input_ids=input_ids,
            labels=labels,
            use_tqdm=True,
        )
def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):
    """Measure throughput of the ZO causal-LM model in ``zo_eval`` mode.

    Calls ``model.zo_eval()`` before every ``check_throughput`` call. Reads the
    module-level ``args`` and ``original_dtype`` assigned in ``__main__``.
    """
    seq_len = model_config.max_position_embeddings
    input_ids, labels = prepare_data_for_causalLM(
        model_config.vocab_size, args.batch_size, seq_len, device)

    # Construct the model under the requested dtype, then restore the default.
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForCausalLM(model_config)
    model = model.to(device)
    model.zo_init(zo_config)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        check_throughput(
            step,
            args.batch_size*model_config.max_position_embeddings,
            model,
            input_ids=input_ids,
            labels=labels,
            use_tqdm=True,
        )
def eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):
    """Measure throughput of the ZO sequence-classification model in ``zo_eval`` mode.

    Calls ``model.zo_eval()`` before every ``check_throughput`` call. Reads the
    module-level ``args`` and ``original_dtype`` assigned in ``__main__``.
    """
    seq_len = model_config.max_position_embeddings
    input_ids, labels = prepare_data_for_sequence_classification(
        model_config.vocab_size, args.batch_size, seq_len, device)

    # Construct the model under the requested dtype, then restore the default.
    torch.set_default_dtype(args.model_dtype)
    model = zo.OPTForSequenceClassification(model_config)
    model = model.to(device)
    model.zo_init(zo_config)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)

    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        check_throughput(
            step,
            args.batch_size*model_config.max_position_embeddings,
            model,
            input_ids=input_ids,
            labels=labels,
            use_tqdm=True,
        )
def train_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):
    """Measure throughput of the ZO question-answering model over ``args.max_steps`` steps.

    Reads the module-level ``args`` and ``original_dtype`` that are assigned in
    the ``__main__`` section of this script.

    Args:
        model_config: OPT model config; ``vocab_size`` and
            ``max_position_embeddings`` are used to build the input batch.
        zo_config: configuration object passed to ``model.zo_init``.
        device: torch device string for the data and the model.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honour the `device` argument (was hard-coded to "cuda") so the model sits
    # on the same device as the inputs created above.
    model = zo.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model,
                         input_ids=input_ids, start_positions=start_positions,
                         end_positions=end_positions, use_tqdm=True)
def eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):
    """Measure throughput of the ZO2 question-answering model in ``zo_eval`` mode.

    Reads the module-level ``args`` and ``original_dtype`` that are assigned in
    the ``__main__`` section of this script.

    Args:
        model_config: OPT model config; ``vocab_size`` and
            ``max_position_embeddings`` are used to build the input batch.
        zo_config: configuration object passed to ``model.zo_init``.
        device: torch device string for the data and the model.
    """
    input_ids, start_positions, end_positions = prepare_data_for_question_answering(
        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)
    torch.set_default_dtype(args.model_dtype)
    # Honour the `device` argument (was hard-coded to "cuda") so the model sits
    # on the same device as the inputs created above.
    model = zo2.OPTForQuestionAnswering(model_config).to(device)
    model.zo_init(zo_config)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    torch.set_default_dtype(original_dtype)
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model,
                         input_ids=input_ids, start_positions=start_positions,
                         end_positions=end_positions, use_tqdm=True)
def test_mezo_sgd_causalLM_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO causal-LM eval speed test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = False

    eval_mezo_sgd_causalLM(opt_cfg, mezo_cfg, device=args.working_device)
def test_mezo_sgd_sequence_classification_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO sequence-classification eval speed test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = False

    eval_mezo_sgd_sequence_classification(opt_cfg, mezo_cfg, device=args.working_device)
def test_mezo2_sgd_question_answering_eval():
    """Seed, build the OPT and MeZO-SGD configs from ``args``, and run the ZO2 question-answering eval speed test."""
    seed_everything(args.seed)

    # Pick the named OPT configuration and adapt it for this benchmark.
    opt_cfg = getattr(OPTConfigs(), args.model_name)
    opt_cfg.tie_word_embeddings = False
    opt_cfg.max_position_embeddings = args.sequence_length

    mezo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    mezo_cfg.zo2 = True

    eval_mezo2_sgd_question_answering(opt_cfg, mezo_cfg, device=args.working_device)
test_mezo2_sgd_causalLM_training() elif args.task == "sequence_classification": if args.eval: test_mezo2_sgd_sequence_classification_eval() else: test_mezo2_sgd_sequence_classification_training() elif args.task == "question_answering": if args.eval: test_mezo2_sgd_question_answering_eval() else: test_mezo2_sgd_question_answering_training() else: raise NotImplementedError(f"Task {args.task} is unsupported.") else: raise NotImplementedError ================================================ FILE: test/mezo_sgd/hf_opt/test_speed_eval.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" CMD1="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval" CMD2="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval" OUT1="/tmp/output1_${model_name}_${task_id}.txt" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Analyzing throughput..." 
# Count the total number of lines and determine the number of iteration lines total_lines1=$(wc -l < $OUT1) total_lines2=$(wc -l < $OUT2) iter_lines1=$(grep -c 'Time cost after iteration' $OUT1) iter_lines2=$(grep -c 'Time cost after iteration' $OUT2) # Calculate the starting line for the last 50% of iterations start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1)))) start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1)))) # Calculate average tokens per second for the last 50% of the iterations avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}') avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}') ratio=$(echo "scale=2; $avg_tok_s2 / $avg_tok_s1 * 100" | bc) echo -e "Model: $model_name, Task: $task_id" echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}" echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}" echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}" rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_opt/test_speed_train.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" CMD1="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30" CMD2="python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30" OUT1="/tmp/output1_${model_name}_${task_id}.txt" 
OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Analyzing throughput..." # Count the total number of lines and determine the number of iteration lines total_lines1=$(wc -l < $OUT1) total_lines2=$(wc -l < $OUT2) iter_lines1=$(grep -c 'Time cost after iteration' $OUT1) iter_lines2=$(grep -c 'Time cost after iteration' $OUT2) # Calculate the starting line for the last 50% of iterations start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1)))) start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1)))) # Calculate average tokens per second for the last 50% of the iterations avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}') avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}') ratio=$(echo "scale=2; $avg_tok_s2 / $avg_tok_s1 * 100" | bc) echo -e "Model: $model_name, Task: $task_id" echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}" echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}" echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}" rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_opt/utils.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import torch import time import argparse from tqdm import tqdm import psutil import os from transformers import OPTConfig import pynvml def get_args(): args = argparse.ArgumentParser() args.add_argument("--zo_method", type=str, default="zo2") args.add_argument("--eval", action="store_true") args.add_argument("--task", type=str, default="causalLM") args.add_argument("--model_name", type=str, default="opt_125m") args.add_argument("--model_dtype", type=str, 
default="fp16") args.add_argument("--verbose", action="store_true") args.add_argument("--max_steps", type=int, default=3) args.add_argument("--lr", type=float, default=1e-3) args.add_argument("--weight_decay", type=float, default=1e-1) args.add_argument("--zo_eps", type=float, default=1e-3) args.add_argument("--seed", type=int, default=42) args.add_argument("--batch_size", type=int, default=1) args.add_argument("--sequence_length", type=int, default=2048) args.add_argument("--overlap", type=str, default="all") args.add_argument("--offloading_device", type=str, default="cpu") args.add_argument("--working_device", type=str, default="cuda:0") args = args.parse_args() args.model_dtype = dtype_lookup[args.model_dtype] return args class OPTConfigs: opt_125m: OPTConfig = OPTConfig(num_hidden_layers=12, num_attention_heads=12, hidden_size=768, ffn_dim=3072, max_position_embeddings=2048) opt_350m: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=16, hidden_size=1024, ffn_dim=4096, max_position_embeddings=2048) opt_1_3b: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=32, hidden_size=2048, ffn_dim=8192, max_position_embeddings=2048) opt_2_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=2560, ffn_dim=10240, max_position_embeddings=2048) opt_6_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=4096, ffn_dim=16384, max_position_embeddings=2048) opt_13b: OPTConfig = OPTConfig(num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, ffn_dim=20480, max_position_embeddings=2048) opt_30b: OPTConfig = OPTConfig(num_hidden_layers=48, num_attention_heads=56, hidden_size=7168, ffn_dim=28672, max_position_embeddings=2048) opt_66b: OPTConfig = OPTConfig(num_hidden_layers=64, num_attention_heads=72, hidden_size=9216, ffn_dim=36864, max_position_embeddings=2048) opt_175b: OPTConfig = OPTConfig(num_hidden_layers=96, num_attention_heads=96, hidden_size=12288, ffn_dim=49152, 
max_position_embeddings=2048) dtype_lookup = { "fp64": torch.float64, "fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16 } def model_size(model: torch.nn.Module): total_size = sum(p.numel() for p in model.parameters()) trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad) return {"total": total_size, "trainable": trainable_size} def prepare_data_for_causalLM(V, B, T, device='cuda'): data_batch = torch.randint(0, V, (B, T)).to(device) input_ids = data_batch labels = data_batch return input_ids, labels def prepare_data_for_sequence_classification(V, B, T, device='cuda'): input_ids = torch.randint(0, V, (B, T)).to(device) labels = torch.randint(0, 1, (B, )).to(device) return input_ids, labels def prepare_data_for_question_answering(V, B, T, device='cuda'): input_ids = torch.randint(0, V, (B, T)).to(device) start_positions = torch.randint(0, 1, (B, )).to(device) end_positions = torch.randint(1, 2, (B, )).to(device) return input_ids, start_positions, end_positions # GPU Memory Monitoring pynvml.nvmlInit() def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False): # Check the peak memory usage handle = pynvml.nvmlDeviceGetHandleByIndex(device) # Adjust index if necessary info = pynvml.nvmlDeviceGetMemoryInfo(handle) peak_memory = info.used / 1024**2 if use_tqdm: tqdm.write("Peak GPU Memory after iteration {}: {:.2f} MB".format(iter+1, peak_memory)) else: print(f"Peak GPU Memory after iteration {iter+1}: {peak_memory:.2f} MB") # CPU Memory Monitoring peak_memory_cpu = 0 def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False): global peak_memory_cpu process = psutil.Process(os.getpid()) current_memory = process.memory_info().rss / (1024 ** 2) # Convert to MB if current_memory > peak_memory_cpu: peak_memory_cpu = current_memory if use_tqdm: tqdm.write(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB") else: print(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB") def 
reset_peak_cpu_memory_usage(): global peak_memory_cpu peak_memory_cpu = 0 if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs): t1 = time.time() out = fn(*args, **kwargs) t2 = time.time() time_cost = t2-t1 throughtput = total_token_batch_size_per_iter / time_cost if use_tqdm: tqdm.write("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughtput)) else: print("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughtput)) ================================================ FILE: test/mezo_sgd/hf_qwen3/record_zo2_memory.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" CMD2="python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo2 --max_steps 30" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD2 2>&1 | tee $OUT2 echo "Recording Peak GPU and CPU Memory usage..." max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then echo "Could not find memory usage data in the output." 
else echo -e "Model: $model_name, Task: $task_id" echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}" echo -e "ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}" fi rm $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_qwen3/record_zo2_speed.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" CMD2="python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo2 --max_steps 30" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD2 2>&1 | tee $OUT2 echo "Recording throughput..." # Count the total number of lines and determine the number of iteration lines total_lines2=$(wc -l < $OUT2) iter_lines2=$(grep -c 'Time cost after iteration' $OUT2) # Calculate the starting line for the last 50% of iterations start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1)))) # Calculate average tokens per second for the last 50% of the iterations avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}') echo -e "Model: $model_name, Task: $task_id" echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}" rm $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_qwen3/test_acc.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import sys sys.path.append("../zo2") import torch from tqdm import tqdm from zo2.config.mezo_sgd import MeZOSGDConfig from zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2 from zo2.utils.utils import seed_everything from utils 
import ( Qwen3Configs, prepare_data_for_causalLM, model_size, get_args ) def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_eval() model.zo_train() loss = model(input_ids=input_ids, labels=labels) res = "Iteration {}, loss: {}, projected grad: {}" tqdm.write(res.format(i, loss, model.opt.projected_grad)) def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_eval() model.zo_train() loss = model(input_ids=input_ids, labels=labels) res = "Iteration {}, loss: {}, projected grad: {}" tqdm.write(res.format(i, loss, model.opt.projected_grad)) def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_train() model.zo_eval() loss = 
model(input_ids=input_ids, labels=labels)["loss"] res = "Iteration {}, loss: {}" tqdm.write(res.format(i, loss)) def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_train() model.zo_eval() loss = model(input_ids=input_ids, labels=labels)["loss"] res = "Iteration {}, loss: {}" tqdm.write(res.format(i, loss)) def test_mezo_sgd_causalLM_training(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length # model_config._attn_implementation = "eager" zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, working_device=args.working_device) zo_cfg.zo2 = False train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo2_sgd_causalLM_training(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length # model_config._attn_implementation = "eager" zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, offloading_device=args.offloading_device, working_device=args.working_device) zo_cfg.zo2 = True zo_cfg.overlap = args.overlap=="all" train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo_sgd_causalLM_eval(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, 
args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, working_device=args.working_device) zo_cfg.zo2 = False eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo2_sgd_causalLM_eval(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, offloading_device=args.offloading_device, working_device=args.working_device) zo_cfg.zo2 = True zo_cfg.overlap = args.overlap=="all" eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device) if __name__=="__main__": args = get_args() # torch.set_printoptions(precision=10) original_dtype = torch.get_default_dtype() if args.zo_method == "zo": if args.eval: test_mezo_sgd_causalLM_eval() else: test_mezo_sgd_causalLM_training() elif args.zo_method == "zo2": if args.eval: test_mezo2_sgd_causalLM_eval() else: test_mezo2_sgd_causalLM_training() else: raise NotImplementedError ================================================ FILE: test/mezo_sgd/hf_qwen3/test_acc_eval.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" if [ "$task_id" == "causalLM" ]; then lr=1e-4 else lr=1e-7 fi CMD1="python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo --lr $lr --eval --max_steps 30" CMD2="python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name 
--zo_method zo2 --lr $lr --eval --max_steps 30" OUT1="/tmp/output1_${model_name}_${task_id}.txt" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Comparing outputs..." echo -e "Model: $model_name, Task: $task_id" paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{ split($4, loss1, ","); split($8, loss2, ","); diff_loss = loss1[1] - loss2[1]; if (loss1[1] == loss2[1]) printf "Iteration %s: %s✓ loss match.%s\n", $2, green, nc; else printf "Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) Loss diff: %.6f%s\n", $2, red, loss1[1], loss2[1], diff_loss, nc; }' rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_qwen3/test_acc_train.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name" if [ "$task_id" == "causalLM" ]; then lr=1e-4 else lr=1e-7 fi CMD1="python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo --lr $lr" CMD2="python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo2 --lr $lr" OUT1="/tmp/output1_${model_name}_${task_id}.txt" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Comparing outputs..." 
echo -e "Model: $model_name, Task: $task_id" paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{ split($4, loss1, ","); split($7, proj1, ","); split($11, loss2, ","); split($14, proj2, ","); diff_loss = loss1[1] - loss2[1]; diff_proj = proj1[1] - proj2[1]; if (loss1[1] == loss2[1] && proj1[1] == proj2[1]) printf "Iteration %d: %s✓ loss and projected grad match.%s\n", $2, green, nc; else printf "Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\n \tLoss diff: %.6f, Proj grad diff: %.6f%s\n", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc; }' rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_qwen3/test_memory.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import sys sys.path.append("../zo2") import torch from tqdm import tqdm from zo2.config.mezo_sgd import MeZOSGDConfig from zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2 from zo2.utils.utils import seed_everything from utils import ( Qwen3Configs, prepare_data_for_causalLM, model_size, get_args, reset_peak_cpu_memory_usage, check_peak_gpu_memory_usage, check_and_update_peak_cpu_memory_usage, ) def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) torch.cuda.reset_peak_memory_stats() reset_peak_cpu_memory_usage() for i in tqdm(range(args.max_steps)): model.zo_train() model(input_ids=input_ids, labels=labels) check_peak_gpu_memory_usage(i, int(device[-1]), 
True) check_and_update_peak_cpu_memory_usage(i, True) def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) torch.cuda.reset_peak_memory_stats() reset_peak_cpu_memory_usage() for i in tqdm(range(args.max_steps)): model.zo_train() model(input_ids=input_ids, labels=labels) check_peak_gpu_memory_usage(i, int(device[-1]), True) check_and_update_peak_cpu_memory_usage(i, True) def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) torch.cuda.reset_peak_memory_stats() reset_peak_cpu_memory_usage() for i in tqdm(range(args.max_steps)): model.zo_eval() model(input_ids=input_ids, labels=labels) check_peak_gpu_memory_usage(i, int(device[-1]), True) check_and_update_peak_cpu_memory_usage(i, True) def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) 
torch.cuda.reset_peak_memory_stats() reset_peak_cpu_memory_usage() for i in tqdm(range(args.max_steps)): model.zo_eval() model(input_ids=input_ids, labels=labels) check_peak_gpu_memory_usage(i, int(device[-1]), True) check_and_update_peak_cpu_memory_usage(i, True) def test_mezo_sgd_causalLM_training(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, working_device=args.working_device) zo_cfg.zo2 = False train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo2_sgd_causalLM_training(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, offloading_device=args.offloading_device, working_device=args.working_device) zo_cfg.zo2 = True train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo_sgd_causalLM_eval(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, working_device=args.working_device) zo_cfg.zo2 = False eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo2_sgd_causalLM_eval(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, 
weight_decay=args.weight_decay, eps=args.zo_eps, offloading_device=args.offloading_device, working_device=args.working_device) zo_cfg.zo2 = True eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device) if __name__=="__main__": args = get_args() original_dtype = torch.get_default_dtype() if args.zo_method == "zo": if args.eval: test_mezo_sgd_causalLM_eval() else: test_mezo_sgd_causalLM_training() elif args.zo_method == "zo2": if args.eval: test_mezo2_sgd_causalLM_eval() else: test_mezo2_sgd_causalLM_training() else: raise NotImplementedError ================================================ FILE: test/mezo_sgd/hf_qwen3/test_memory_train.sh ================================================ #!/bin/bash set -e set -o pipefail model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b") task_ids=("causalLM") # ANSI color codes GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' for model_name in "${model_names[@]}" do for task_id in "${task_ids[@]}" do echo "Testing model_name: $model_name, task_id: $task_id" CMD1="python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo --max_steps 30" CMD2="python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo2 --max_steps 30" OUT1="/tmp/output1_${model_name}_${task_id}.txt" OUT2="/tmp/output2_${model_name}_${task_id}.txt" $CMD1 2>&1 | tee $OUT1 $CMD2 2>&1 | tee $OUT2 echo "Analyzing Peak GPU Memory usage..." max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1) if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then echo "Could not find memory usage data in the output." 
else ratio=$(echo "scale=2; $max_mem2 / $max_mem1 * 100" | bc) echo -e "Model: $model_name, Task: $task_id" echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}" echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}" echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}" fi rm $OUT1 $OUT2 done done ================================================ FILE: test/mezo_sgd/hf_qwen3/test_speed.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import sys sys.path.append("../zo2") import torch from tqdm import tqdm from zo2.config.mezo_sgd import MeZOSGDConfig from zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2 from zo2.utils.utils import seed_everything from utils import ( Qwen3Configs, prepare_data_for_causalLM, model_size, get_args, check_throughput ) def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_train() check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True) def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in 
tqdm(range(args.max_steps)): model.zo_train() check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True) def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo.Qwen3ForCausalLM(model_config).to(device) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_eval() check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True) def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): input_ids, labels = prepare_data_for_causalLM( model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device) torch.set_default_dtype(args.model_dtype) model = zo2.Qwen3ForCausalLM(model_config) model.zo_init(zo_config) total_parameters = model_size(model)["total"] print(f"model size: {total_parameters/1024**3:.2f} B") torch.set_default_dtype(original_dtype) for i in tqdm(range(args.max_steps)): model.zo_eval() check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True) def test_mezo_sgd_causalLM_training(): seed_everything(args.seed) model_configs = Qwen3Configs() model_config = getattr(model_configs, args.model_name) model_config.tie_word_embeddings=False model_config.max_position_embeddings = args.sequence_length zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps, working_device=args.working_device) zo_cfg.zo2 = False train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device) def test_mezo2_sgd_causalLM_training(): seed_everything(args.seed) 
def test_mezo_sgd_causalLM_eval():
    """Run the eval-mode throughput benchmark with ZO (``zo2 = False``)."""
    seed_everything(args.seed)
    model_config = getattr(Qwen3Configs(), args.model_name)
    model_config.tie_word_embeddings = False
    # The benchmark sequence length is carried in max_position_embeddings.
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = False
    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)


def test_mezo2_sgd_causalLM_eval():
    """Run the eval-mode throughput benchmark with ZO2 (``zo2 = True``)."""
    seed_everything(args.seed)
    model_config = getattr(Qwen3Configs(), args.model_name)
    model_config.tie_word_embeddings = False
    # The benchmark sequence length is carried in max_position_embeddings.
    model_config.max_position_embeddings = args.sequence_length
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = True
    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)


if __name__=="__main__":
    # `args` and `original_dtype` are module globals read by the functions above.
    args = get_args()
    original_dtype = torch.get_default_dtype()

    # (zo_method, eval flag) -> benchmark entry point
    entry_points = {
        ("zo", True): test_mezo_sgd_causalLM_eval,
        ("zo", False): test_mezo_sgd_causalLM_training,
        ("zo2", True): test_mezo2_sgd_causalLM_eval,
        ("zo2", False): test_mezo2_sgd_causalLM_training,
    }
    entry_point = entry_points.get((args.zo_method, bool(args.eval)))
    if entry_point is None:
        raise NotImplementedError
    entry_point()
#!/bin/bash
# Compare ZO and ZO2 training throughput for each Qwen3 configuration.

set -e
set -o pipefail

model_names=("qwen3_0_6b" "qwen3_1_7b" "qwen3_4b" "qwen3_8b" "qwen3_14b" "qwen3_32b")
task_ids=("causalLM")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

# Print the mean tok/s over the last 50% of the "Time cost after iteration"
# lines in file $1. Prints nothing when the file has no such lines.
avg_last_half_throughput() {
    local out_file=$1
    local total_lines iter_lines start_line
    total_lines=$(wc -l < "$out_file")
    # grep -c exits 1 on zero matches; keep going so the caller can report it.
    iter_lines=$(grep -c 'Time cost after iteration' "$out_file" || true)
    # Starting line for the last 50% of iterations
    start_line=$((total_lines - iter_lines + iter_lines / 2 + 1))
    tail -n +"$start_line" "$out_file" \
        | awk '/tok\/s/ {total += $8; count++} END {if (count) printf "%.2f", total / count}'
}

for model_name in "${model_names[@]}"
do
    for task_id in "${task_ids[@]}"
    do
        echo "Testing model_name: $model_name, task_id: $task_id"

        CMD1="python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo --max_steps 30"
        CMD2="python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo2 --max_steps 30"

        OUT1="/tmp/output1_${model_name}_${task_id}.txt"
        OUT2="/tmp/output2_${model_name}_${task_id}.txt"

        # CMD1/CMD2 are intentionally unquoted so they split into command + arguments.
        $CMD1 2>&1 | tee "$OUT1"
        $CMD2 2>&1 | tee "$OUT2"

        echo "Analyzing throughput..."
        avg_tok_s1=$(avg_last_half_throughput "$OUT1")
        avg_tok_s2=$(avg_last_half_throughput "$OUT2")

        if [ -z "$avg_tok_s1" ] || [ -z "$avg_tok_s2" ]; then
            echo "Could not find throughput data in the output."
        else
            # Multiply before dividing: with scale=2 a leading division would
            # truncate the quotient to two decimals before it is scaled by 100.
            ratio=$(echo "scale=2; $avg_tok_s2 * 100 / $avg_tok_s1" | bc)
            echo -e "Model: $model_name, Task: $task_id"
            echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}"
            echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
            echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}"
        fi

        rm "$OUT1" "$OUT2"
    done
done
def get_args(argv=None):
    """Parse the benchmark command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The parsed namespace, with ``model_dtype`` replaced by the matching
        entry of the module-level ``dtype_lookup`` table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--zo_method", type=str, default="zo2")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--model_name", type=str, default="qwen3_0_6b")
    # Restricting the choices turns an unknown dtype name into a normal
    # argparse usage error instead of a KeyError from the lookup below.
    parser.add_argument("--model_dtype", type=str, default="fp16", choices=sorted(dtype_lookup))
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--max_steps", type=int, default=3)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-1)
    parser.add_argument("--zo_eps", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--sequence_length", type=int, default=2048)
    parser.add_argument("--overlap", type=str, default="all")
    parser.add_argument("--offloading_device", type=str, default="cpu")
    parser.add_argument("--working_device", type=str, default="cuda:0")
    args = parser.parse_args(argv)
    args.model_dtype = dtype_lookup[args.model_dtype]
    return args


class Qwen3Configs:
    """Named ``Qwen3Config`` presets, selected by attribute name via ``--model_name``.

    The attributes are class-level objects created once at import time, so a
    caller that mutates a preset changes it for the rest of the process.
    """
    qwen3_0_6b: Qwen3Config = Qwen3Config(num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=8, max_window_layers=28,
                                          hidden_size=1024, intermediate_size=3072, max_position_embeddings=40960, use_sliding_window=False)
    qwen3_1_7b: Qwen3Config = Qwen3Config(num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=8, max_window_layers=28,
                                          hidden_size=2048, intermediate_size=6144, max_position_embeddings=40960, use_sliding_window=False)
    qwen3_4b: Qwen3Config = Qwen3Config(num_hidden_layers=36, num_attention_heads=32, num_key_value_heads=8, max_window_layers=36,
                                        hidden_size=2560, intermediate_size=9728, max_position_embeddings=40960, use_sliding_window=False)
    qwen3_8b: Qwen3Config = Qwen3Config(num_hidden_layers=36, num_attention_heads=32, num_key_value_heads=8, max_window_layers=36,
                                        hidden_size=4096, intermediate_size=12288, max_position_embeddings=40960, use_sliding_window=False)
    qwen3_14b: Qwen3Config = Qwen3Config(num_hidden_layers=40, num_attention_heads=40, num_key_value_heads=8, max_window_layers=40,
                                         hidden_size=5120, intermediate_size=17408, max_position_embeddings=40960, use_sliding_window=False)
    qwen3_32b: Qwen3Config = Qwen3Config(num_hidden_layers=64, num_attention_heads=64, num_key_value_heads=8, max_window_layers=64,
                                         hidden_size=5120, intermediate_size=25600, max_position_embeddings=40960, use_sliding_window=False)
# Maps the --model_dtype command-line names to torch dtypes.
dtype_lookup = {
    "fp64": torch.float64,
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16
}


def model_size(model: torch.nn.Module):
    """Count the elements yielded by ``model.parameters()``.

    Returns:
        A dict with the overall element count under ``"total"`` and the count
        restricted to parameters with ``requires_grad`` under ``"trainable"``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {"total": total, "trainable": trainable}


def prepare_data_for_causalLM(V, B, T, device='cuda'):
    """Create one random ``(B, T)`` batch of token ids drawn from ``[0, V)``.

    The same tensor is returned as both ``input_ids`` and ``labels``.
    """
    tokens = torch.randint(0, V, (B, T)).to(device)
    return tokens, tokens


# GPU Memory Monitoring
pynvml.nvmlInit()


def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):
    """Print the memory NVML currently reports as used on GPU ``device``, in MB.

    Despite the "Peak" label this is a single reading taken at call time; no
    running maximum is kept here.

    Args:
        iter: Zero-based iteration index; printed as ``iter + 1``.
        device: NVML device index.
        use_tqdm: Emit through ``tqdm.write`` instead of ``print``.
    """
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    used_mb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024**2
    if use_tqdm:
        tqdm.write("Peak GPU Memory after iteration {}: {:.2f} MB".format(iter+1, used_mb))
        return
    print(f"Peak GPU Memory after iteration {iter+1}: {used_mb:.2f} MB")


# CPU Memory Monitoring
peak_memory_cpu = 0
def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):
    """Update the running peak of this process's RSS and print it in MB.

    Args:
        iter: Zero-based iteration index; printed as ``iter + 1``.
        use_tqdm: Emit through ``tqdm.write`` instead of ``print``.
    """
    global peak_memory_cpu
    process = psutil.Process(os.getpid())
    current_memory = process.memory_info().rss / (1024 ** 2)  # Convert to MB
    if current_memory > peak_memory_cpu:
        peak_memory_cpu = current_memory
    if use_tqdm:
        tqdm.write(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB")
    else:
        print(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB")


def reset_peak_cpu_memory_usage():
    """Reset the running CPU peak and, when CUDA is available, torch's peak stats."""
    global peak_memory_cpu
    peak_memory_cpu = 0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def _cuda_synchronize_all():
    """Wait for queued work on every visible CUDA device; no-op without CUDA."""
    if torch.cuda.is_available():
        for index in range(torch.cuda.device_count()):
            torch.cuda.synchronize(index)


def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs):
    """Time one call of ``fn(*args, **kwargs)`` and print latency and tokens/s.

    CUDA work is launched asynchronously, so the devices are synchronised
    before starting and before stopping the clock; otherwise the measurement
    could cover little more than kernel launch time.

    Args:
        iter: Zero-based iteration index; printed as ``iter + 1``.
        total_token_batch_size_per_iter: Tokens processed by one call of ``fn``.
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``.
        use_tqdm: Emit through ``tqdm.write`` instead of ``print``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returned.
    """
    _cuda_synchronize_all()
    # perf_counter is monotonic and has a finer resolution than time.time().
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    _cuda_synchronize_all()
    time_cost = time.perf_counter() - start

    # Guard against a zero interval on coarse clocks.
    if time_cost > 0:
        throughput = total_token_batch_size_per_iter / time_cost
    else:
        throughput = float("inf")

    message = "Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughput)
    if use_tqdm:
        tqdm.write(message)
    else:
        print(message)
    return out
#!/bin/bash
# Record ZO2 peak GPU and CPU memory for each nanogpt model configuration.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD2="python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30"

    OUT2="/tmp/output2_$model_id.txt"

    # CMD2 is intentionally unquoted so it splits into command + arguments.
    $CMD2 2>&1 | tee "$OUT2"

    echo "Analyzing Peak GPU and CPU Memory usage..."
    # awk exits 0 even when nothing matches; a grep here would abort the
    # script under pipefail before the empty-result check below could run.
    max_mem1=$(awk '/Peak GPU Memory/ {print $7}' "$OUT2" | sort -nr | head -1)
    max_mem2=$(awk '/Peak CPU Memory/ {print $7}' "$OUT2" | sort -nr | head -1)

    if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
        echo "Could not find memory usage data in the output."
    else
        # The loop variable is model_id; model_name is never set in this script.
        echo -e "Model: $model_id"
        echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
        echo -e "ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}"
    fi

    rm "$OUT2"
done

================================================
FILE: test/mezo_sgd/nanogpt/record_zo2_speed.sh
================================================
#!/bin/bash
# Record ZO2 training throughput for each nanogpt model configuration.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD2="python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30"

    OUT2="/tmp/output2_$model_id.txt"

    # CMD2 is intentionally unquoted so it splits into command + arguments.
    $CMD2 2>&1 | tee "$OUT2"

    echo "Analyzing throughput..."
    # Count the total number of lines and determine the number of iteration lines
    total_lines2=$(wc -l < "$OUT2")
    # grep -c exits 1 on zero matches; keep going so the check below can report it.
    iter_lines2=$(grep -c 'Time cost after iteration' "$OUT2" || true)

    # Calculate the starting line for the last 50% of iterations
    start_line2=$((total_lines2 - iter_lines2 + iter_lines2 / 2 + 1))

    # Calculate average tokens per second for the last 50% of the iterations
    avg_tok_s2=$(tail -n +"$start_line2" "$OUT2" \
        | awk '/tok\/s/ {total += $8; count++} END {if (count) printf "%.2f", total / count}')

    if [ -z "$avg_tok_s2" ]; then
        echo "Could not find throughput data in the output."
    else
        # The loop variable is model_id; model_name is never set in this script.
        echo -e "Model: $model_id"
        echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
    fi

    rm "$OUT2"
done
def _announce_and_load(model, args, model_config, device):
    """Seed the RNGs, print the parameter count and build one batch of data."""
    seed_everything(args.seed)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    print("Init dataset")
    return prepare_data(model_config.vocab_size, args.batch_size, model_config.block_size, device=device)


def train_mezo_sgd(model, args, model_config, device='cuda'):
    """Run ``args.max_steps`` zo_train steps, logging loss and projected grad."""
    input_ids, pos, labels = _announce_and_load(model, args, model_config, device)
    template = "Iteration {}, loss: {}, projected grad: {}"
    for step in tqdm(range(args.max_steps)):
        model.zo_train()
        loss = model(input_ids, pos, labels)
        tqdm.write(template.format(step, loss, model.opt.projected_grad))


def train_mezo2_sgd(model, args, model_config, device='cuda'):
    """Run ``args.max_steps`` zo_train steps, logging loss and projected grad.

    Same loop as ``train_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    input_ids, pos, labels = _announce_and_load(model, args, model_config, device)
    template = "Iteration {}, loss: {}, projected grad: {}"
    for step in tqdm(range(args.max_steps)):
        model.zo_train()
        loss = model(input_ids, pos, labels)
        tqdm.write(template.format(step, loss, model.opt.projected_grad))


def eval_mezo_sgd(model, args, model_config, device='cuda'):
    """Run ``args.max_steps`` zo_eval steps, logging the last output element as loss."""
    input_ids, pos, labels = _announce_and_load(model, args, model_config, device)
    template = "Iteration {}, loss: {}"
    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        outputs = model(input_ids, pos, labels)
        tqdm.write(template.format(step, outputs[-1]))


def eval_mezo2_sgd(model, args, model_config, device='cuda'):
    """Run ``args.max_steps`` zo_eval steps, logging the last output element as loss.

    Same loop as ``eval_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    input_ids, pos, labels = _announce_and_load(model, args, model_config, device)
    template = "Iteration {}, loss: {}"
    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        outputs = model(input_ids, pos, labels)
        tqdm.write(template.format(step, outputs[-1]))
def _build_model(use_zo2):
    """Seed, pick the GPT config and build the MeZO-SGD model.

    With ``use_zo2`` false the model is moved to ``args.working_device``;
    with ``use_zo2`` true the offloading device is passed to the ZO config
    and the model is left where it was constructed.

    Returns:
        ``(cfg, model)``.
    """
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)

    zo_kwargs = {
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "eps": args.zo_eps,
        "working_device": args.working_device,
    }
    if use_zo2:
        zo_kwargs["offloading_device"] = args.offloading_device
    zo_cfg = MeZOSGDConfig(**zo_kwargs)
    zo_cfg.zo2 = use_zo2

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)
    if not use_zo2:
        model = model.to(args.working_device)
    torch.set_default_dtype(original_dtype)
    return cfg, model


def test_mezo_sgd_training():
    """Training accuracy run with ZO (``zo2 = False``)."""
    cfg, model_mezo = _build_model(False)
    train_mezo_sgd(model=model_mezo, args=args, model_config=cfg, device=args.working_device)


def test_mezo2_sgd_training():
    """Training accuracy run with ZO2 (``zo2 = True``)."""
    cfg, model = _build_model(True)
    train_mezo2_sgd(model=model, args=args, model_config=cfg, device=args.working_device)


def test_mezo_sgd_eval():
    """Eval accuracy run with ZO (``zo2 = False``)."""
    cfg, model_mezo = _build_model(False)
    eval_mezo_sgd(model=model_mezo, args=args, model_config=cfg, device=args.working_device)


def test_mezo2_sgd_eval():
    """Eval accuracy run with ZO2 (``zo2 = True``)."""
    cfg, model = _build_model(True)
    eval_mezo2_sgd(model=model, args=args, model_config=cfg, device=args.working_device)
if __name__ == "__main__":
    # `args` and `original_dtype` are module globals read by the test_* functions.
    args = get_args()
    original_dtype = torch.get_default_dtype()

    # Pick the entry point first, then run it once.
    if args.zo_method == "zo":
        run = test_mezo_sgd_eval if args.eval else test_mezo_sgd_training
    elif args.zo_method == "zo2":
        run = test_mezo2_sgd_eval if args.eval else test_mezo2_sgd_training
    else:
        raise NotImplementedError
    run()
#!/bin/bash
# Check that ZO and ZO2 report the same eval loss at every iteration.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD1="python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo --eval"
    CMD2="python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo2 --eval"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # CMD1/CMD2 are intentionally unquoted so they split into command + arguments.
    $CMD1 2>&1 | tee "$OUT1"
    $CMD2 2>&1 | tee "$OUT2"

    echo "Comparing outputs..."
    echo -e "Model: $model_id"
    # Each log line reads "Iteration <i>, loss: <value>". paste joins the ZO
    # and ZO2 lines with a tab, so split on tabs and compare the text that
    # follows "loss: " on each side. (Splitting field 2 on ":" yields two
    # empty values, which made every iteration report a match.)
    paste <(grep 'Iteration' "$OUT1") <(grep 'Iteration' "$OUT2") | awk -F'\t' -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
        iter = $1; sub(/.*Iteration /, "", iter); sub(/,.*/, "", iter);
        loss1 = $1; sub(/.*loss: /, "", loss1);
        loss2 = $2; sub(/.*loss: /, "", loss2);
        # Numeric view for the diff: drop any non-numeric prefix such as "tensor(".
        num1 = loss1; sub(/^[^0-9.+-]*/, "", num1);
        num2 = loss2; sub(/^[^0-9.+-]*/, "", num2);
        diff_loss = num1 - num2;
        if (loss1 == loss2)
            printf "Iteration %s: %s✓ loss match.%s\n", iter, green, nc;
        else
            printf "Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) \tLoss diff: %.6f%s\n", iter, red, loss1, loss2, diff_loss, nc;
    }'

    rm "$OUT1" "$OUT2"
done

================================================
FILE: test/mezo_sgd/nanogpt/test_acc_train.sh
================================================
#!/bin/bash
# Check that ZO and ZO2 report the same training loss and projected grad.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD1="python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo"
    CMD2="python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo2"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # CMD1/CMD2 are intentionally unquoted so they split into command + arguments.
    $CMD1 2>&1 | tee "$OUT1"
    $CMD2 2>&1 | tee "$OUT2"

    echo "Comparing outputs..."
    # The loop variable is model_id; model_name is never set in this script.
    echo -e "Model: $model_id"
    # Fields 4/7 hold ZO loss/grad and fields 11/14 the ZO2 ones, for lines of
    # the form "Iteration <i>, loss: <l>, projected grad: <g>".
    paste <(grep 'Iteration' "$OUT1") <(grep 'Iteration' "$OUT2") | awk -v green="$GREEN" -v red="$RED" -v nc="$NC" '{
        split($4, loss1, ",");
        split($7, proj1, ",");
        split($11, loss2, ",");
        split($14, proj2, ",");
        diff_loss = loss1[1] - loss2[1];
        diff_proj = proj1[1] - proj2[1];
        if (loss1[1] == loss2[1] && proj1[1] == proj2[1])
            printf "Iteration %d: %s✓ loss and projected grad match.%s\n", $2, green, nc;
        else
            printf "Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\n \tLoss diff: %.6f, Proj grad diff: %.6f%s\n", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;
    }'

    rm "$OUT1" "$OUT2"
done
def train_mezo_sgd(model, args, modelConfig, device='cuda:0'):
    """Run ``args.max_steps`` zo_train steps, printing GPU and CPU memory after each.

    Args:
        model: Model exposing ``zo_train()`` and callable on ``(input_ids, pos, labels)``.
        args: Parsed options; ``seed``, ``batch_size`` and ``max_steps`` are read.
        modelConfig: Config providing ``vocab_size`` and ``block_size``.
        device: Device for the input data, e.g. ``'cuda:0'``.
    """
    seed_everything(args.seed)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    # Parse the device ordinal properly: int(device[-1]) raised for 'cuda' and
    # returned the wrong index for two-digit ordinals such as 'cuda:10'.
    gpu_index = torch.device(device).index or 0
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model(input_ids, pos, labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)


def train_mezo2_sgd(model, args, modelConfig, device='cuda:0'):
    """Run ``args.max_steps`` zo_train steps, printing GPU and CPU memory after each.

    Same loop as ``train_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    seed_everything(args.seed)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    # Parse the device ordinal properly: int(device[-1]) raised for 'cuda' and
    # returned the wrong index for two-digit ordinals such as 'cuda:10'.
    gpu_index = torch.device(device).index or 0
    for i in tqdm(range(args.max_steps)):
        model.zo_train()
        model(input_ids, pos, labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)
def eval_mezo_sgd(model, args, modelConfig, device='cuda:0'):
    """Run ``args.max_steps`` zo_eval steps, printing GPU and CPU memory after each.

    Args:
        model: Model exposing ``zo_eval()`` and callable on ``(input_ids, pos, labels)``.
        args: Parsed options; ``seed``, ``batch_size`` and ``max_steps`` are read.
        modelConfig: Config providing ``vocab_size`` and ``block_size``.
        device: Device for the input data, e.g. ``'cuda:0'``.
    """
    seed_everything(args.seed)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    # Parse the device ordinal properly: int(device[-1]) raised for 'cuda' and
    # returned the wrong index for two-digit ordinals such as 'cuda:10'.
    gpu_index = torch.device(device).index or 0
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model(input_ids, pos, labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)


def eval_mezo2_sgd(model, args, modelConfig, device='cuda:0'):
    """Run ``args.max_steps`` zo_eval steps, printing GPU and CPU memory after each.

    Same loop as ``eval_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    seed_everything(args.seed)
    total_parameters = model_size(model)["total"]
    print(f"model size: {total_parameters/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    torch.cuda.reset_peak_memory_stats()
    reset_peak_cpu_memory_usage()
    # Parse the device ordinal properly: int(device[-1]) raised for 'cuda' and
    # returned the wrong index for two-digit ordinals such as 'cuda:10'.
    gpu_index = torch.device(device).index or 0
    for i in tqdm(range(args.max_steps)):
        model.zo_eval()
        model(input_ids, pos, labels)
        check_peak_gpu_memory_usage(i, gpu_index, True)
        check_and_update_peak_cpu_memory_usage(i, True)


def test_mezo_sgd_training():
    """Training memory run with ZO (``zo2 = False``); the model is moved to the working device."""
    seed_everything(args.seed)
    cfgs = GPTConfigs()
    cfg = getattr(cfgs, args.model_id)
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
        working_device=args.working_device)
    zo_cfg.zo2 = False
    # Construct under the requested dtype, then restore the global default.
    torch.set_default_dtype(args.model_dtype)
    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)
    torch.set_default_dtype(original_dtype)
    train_mezo_sgd(model=model_mezo, args=args, modelConfig=cfg, device=args.working_device)


def test_mezo2_sgd_training():
    """Training memory run with ZO2 (``zo2 = True``); the model is not moved here."""
    seed_everything(args.seed)
    cfgs = GPTConfigs()
    cfg = getattr(cfgs, args.model_id)
    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,
        offloading_device=args.offloading_device, working_device=args.working_device)
    zo_cfg.zo2 = True
    # Construct under the requested dtype, then restore the global default.
    torch.set_default_dtype(args.model_dtype)
    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)
    torch.set_default_dtype(original_dtype)
    train_mezo2_sgd(model=model, args=args, modelConfig=cfg, device=args.working_device)
def test_mezo_sgd_eval():
    """Eval memory run with ZO (``zo2 = False``); the model is moved to the working device."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = False

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model_mezo = build_model(cfg, zo_cfg).to(args.working_device)
    torch.set_default_dtype(original_dtype)

    eval_mezo_sgd(model=model_mezo, args=args, modelConfig=cfg, device=args.working_device)


def test_mezo2_sgd_eval():
    """Eval memory run with ZO2 (``zo2 = True``); the model is not moved here."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = True

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model = build_model(cfg, zo_cfg)
    torch.set_default_dtype(original_dtype)

    eval_mezo2_sgd(model=model, args=args, modelConfig=cfg, device=args.working_device)


if __name__ == "__main__":
    # `args` and `original_dtype` are module globals read by the test_* functions.
    args = get_args()
    original_dtype = torch.get_default_dtype()

    # Pick the entry point first, then run it once.
    if args.zo_method == "zo":
        run = test_mezo_sgd_eval if args.eval else test_mezo_sgd_training
    elif args.zo_method == "zo2":
        run = test_mezo2_sgd_eval if args.eval else test_mezo2_sgd_training
    else:
        raise NotImplementedError
    run()
#!/bin/bash
# Compare ZO and ZO2 peak GPU memory in eval mode for each nanogpt configuration.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD1="python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo --max_steps 30 --eval"
    CMD2="python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30 --eval"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # CMD1/CMD2 are intentionally unquoted so they split into command + arguments.
    $CMD1 2>&1 | tee "$OUT1"
    $CMD2 2>&1 | tee "$OUT2"

    echo "Analyzing Peak GPU Memory usage..."
    # awk exits 0 even when nothing matches; a grep here would abort the
    # script under pipefail before the empty-result check below could run.
    max_mem1=$(awk '/Peak GPU Memory/ {print $7}' "$OUT1" | sort -nr | head -1)
    max_mem2=$(awk '/Peak GPU Memory/ {print $7}' "$OUT2" | sort -nr | head -1)

    if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
        echo "Could not find memory usage data in the output."
    else
        # Multiply before dividing: with scale=2 a leading division would
        # truncate the quotient to two decimals before it is scaled by 100.
        ratio=$(echo "scale=2; $max_mem2 * 100 / $max_mem1" | bc)
        # The loop variable is model_id; model_name is never set in this script.
        echo -e "Model: $model_id"
        echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
        echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}"
        echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}"
    fi

    rm "$OUT1" "$OUT2"
done

================================================
FILE: test/mezo_sgd/nanogpt/test_memory_train.sh
================================================
#!/bin/bash
# Compare ZO and ZO2 peak GPU memory in training mode for each nanogpt configuration.

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD1="python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo --max_steps 30"
    CMD2="python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # CMD1/CMD2 are intentionally unquoted so they split into command + arguments.
    $CMD1 2>&1 | tee "$OUT1"
    $CMD2 2>&1 | tee "$OUT2"

    echo "Analyzing Peak GPU Memory usage..."
    # awk exits 0 even when nothing matches; a grep here would abort the
    # script under pipefail before the empty-result check below could run.
    max_mem1=$(awk '/Peak GPU Memory/ {print $7}' "$OUT1" | sort -nr | head -1)
    max_mem2=$(awk '/Peak GPU Memory/ {print $7}' "$OUT2" | sort -nr | head -1)

    if [ -z "$max_mem1" ] || [ -z "$max_mem2" ]; then
        echo "Could not find memory usage data in the output."
    else
        # Multiply before dividing: with scale=2 a leading division would
        # truncate the quotient to two decimals before it is scaled by 100.
        ratio=$(echo "scale=2; $max_mem2 * 100 / $max_mem1" | bc)
        # The loop variable is model_id; model_name is never set in this script.
        echo -e "Model: $model_id"
        echo -e "ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}"
        echo -e "ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}"
        echo -e "Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}"
    fi

    rm "$OUT1" "$OUT2"
done
def train_mezo_sgd(model, args, modelConfig, device='cuda'):
    """Time ``args.max_steps`` zo_train forward calls and print tokens/s per step."""
    seed_everything(args.seed)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(
        modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    batch = (input_ids, pos, labels)
    tokens_per_step = args.batch_size * modelConfig.block_size
    for step in tqdm(range(args.max_steps)):
        model.zo_train()
        check_throughput(step, tokens_per_step, model, *batch, use_tqdm=True)


def train_mezo2_sgd(model, args, modelConfig, device='cuda'):
    """Time ``args.max_steps`` zo_train forward calls and print tokens/s per step.

    Same loop as ``train_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    seed_everything(args.seed)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(
        modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    batch = (input_ids, pos, labels)
    tokens_per_step = args.batch_size * modelConfig.block_size
    for step in tqdm(range(args.max_steps)):
        model.zo_train()
        check_throughput(step, tokens_per_step, model, *batch, use_tqdm=True)
def eval_mezo_sgd(model, args, modelConfig, device='cuda'):
    """Time ``args.max_steps`` zo_eval forward calls and print tokens/s per step."""
    seed_everything(args.seed)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(
        modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    batch = (input_ids, pos, labels)
    tokens_per_step = args.batch_size * modelConfig.block_size
    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        check_throughput(step, tokens_per_step, model, *batch, use_tqdm=True)


def eval_mezo2_sgd(model, args, modelConfig, device='cuda'):
    """Time ``args.max_steps`` zo_eval forward calls and print tokens/s per step.

    Same loop as ``eval_mezo_sgd``; kept as a separate entry point for the
    ZO2 model.
    """
    seed_everything(args.seed)
    n_params = model_size(model)["total"]
    print(f"model size: {n_params/1024**3:.2f} B")
    print("Init dataset")
    input_ids, pos, labels = prepare_data(
        modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)
    batch = (input_ids, pos, labels)
    tokens_per_step = args.batch_size * modelConfig.block_size
    for step in tqdm(range(args.max_steps)):
        model.zo_eval()
        check_throughput(step, tokens_per_step, model, *batch, use_tqdm=True)


def test_mezo_sgd_training():
    """Training throughput run with ZO (``zo2 = False``); the model is moved to the working device."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = False

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model_mezo = build_model(cfg, zo_cfg).to(args.working_device)
    torch.set_default_dtype(original_dtype)

    train_mezo_sgd(model=model_mezo, args=args, modelConfig=cfg, device=args.working_device)


def test_mezo2_sgd_training():
    """Training throughput run with ZO2 (``zo2 = True``); the model is not moved here."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = True

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model = build_model(cfg, zo_cfg)
    torch.set_default_dtype(original_dtype)

    train_mezo2_sgd(model=model, args=args, modelConfig=cfg, device=args.working_device)
def test_mezo_sgd_eval():
    """Eval throughput run with ZO (``zo2 = False``); the model is moved to the working device."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = False

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model_mezo = build_model(cfg, zo_cfg).to(args.working_device)
    torch.set_default_dtype(original_dtype)

    eval_mezo_sgd(model=model_mezo, args=args, modelConfig=cfg, device=args.working_device)


def test_mezo2_sgd_eval():
    """Eval throughput run with ZO2 (``zo2 = True``); the model is not moved here."""
    seed_everything(args.seed)
    cfg = getattr(GPTConfigs(), args.model_id)
    zo_cfg = MeZOSGDConfig(
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.zo_eps,
        offloading_device=args.offloading_device,
        working_device=args.working_device,
    )
    zo_cfg.zo2 = True

    # Construct under the requested dtype, then put the global default back.
    torch.set_default_dtype(args.model_dtype)
    build_model = get_nanogpt_mezo_sgd(zo_cfg)
    model = build_model(cfg, zo_cfg)
    torch.set_default_dtype(original_dtype)

    eval_mezo2_sgd(model=model, args=args, modelConfig=cfg, device=args.working_device)


if __name__ == "__main__":
    # `args` and `original_dtype` are module globals read by the test_* functions.
    args = get_args()
    original_dtype = torch.get_default_dtype()

    # (zo_method, eval flag) -> benchmark entry point
    entry_points = {
        ("zo", True): test_mezo_sgd_eval,
        ("zo", False): test_mezo_sgd_training,
        ("zo2", True): test_mezo2_sgd_eval,
        ("zo2", False): test_mezo2_sgd_training,
    }
    entry_point = entry_points.get((args.zo_method, bool(args.eval)))
    if entry_point is None:
        raise NotImplementedError
    entry_point()
--model_id $model_id --zo_method zo --max_steps 30 --eval"
    CMD2="python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30 --eval"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # Run ZO then ZO2; tee shows the output live and saves a copy for the analysis below.
    $CMD1 2>&1 | tee $OUT1
    $CMD2 2>&1 | tee $OUT2

    echo "Analyzing throughput..."

    # Count the total number of lines and determine the number of iteration lines
    total_lines1=$(wc -l < $OUT1)
    total_lines2=$(wc -l < $OUT2)
    iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)
    iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)

    # Calculate the starting line for the last 50% of iterations
    start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))
    start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))

    # Calculate average tokens per second for the last 50% of the iterations
    # (field 8 of a "Time cost after iteration N: X ms, Y tok/s" line is Y)
    avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
    avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')

    # Multiply before dividing: with scale=2, bc truncates a quotient to two
    # decimals, so dividing first would drop the fractional part of the percentage.
    ratio=$(echo "scale=2; $avg_tok_s2 * 100 / $avg_tok_s1" | bc)

    # The loop variable is model_id; model_name was never defined and printed as empty.
    echo -e "Model: $model_id"
    echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}"
    echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
    echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}"

    rm $OUT1 $OUT2
done

================================================
FILE: test/mezo_sgd/nanogpt/test_speed_train.sh
================================================
#!/bin/bash

set -e
set -o pipefail

model_ids=("gpt2" "gpt2_medium" "gpt2_large" "gpt2_xl" "opt_125m" "opt_350m" "opt_1_3b" "opt_2_7b" "opt_6_7b" "opt_13b" "opt_30b" "opt_66b" "opt_175b")

# ANSI color codes
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m'

for model_id in "${model_ids[@]}"
do
    echo "Testing model_id: $model_id"

    CMD1="python test/mezo_sgd/nanogpt/test_speed.py
--model_id $model_id --zo_method zo --max_steps 30"
    CMD2="python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30"

    OUT1="/tmp/output1_$model_id.txt"
    OUT2="/tmp/output2_$model_id.txt"

    # Run ZO then ZO2; tee shows the output live and saves a copy for the analysis below.
    $CMD1 2>&1 | tee $OUT1
    $CMD2 2>&1 | tee $OUT2

    echo "Analyzing throughput..."

    # Count the total number of lines and determine the number of iteration lines
    total_lines1=$(wc -l < $OUT1)
    total_lines2=$(wc -l < $OUT2)
    iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)
    iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)

    # Calculate the starting line for the last 50% of iterations
    start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))
    start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))

    # Calculate average tokens per second for the last 50% of the iterations
    # (field 8 of a "Time cost after iteration N: X ms, Y tok/s" line is Y)
    avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')
    avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')

    # Multiply before dividing: with scale=2, bc truncates a quotient to two
    # decimals, so dividing first would drop the fractional part of the percentage.
    ratio=$(echo "scale=2; $avg_tok_s2 * 100 / $avg_tok_s1" | bc)

    # The loop variable is model_id; model_name was never defined and printed as empty.
    echo -e "Model: $model_id"
    echo -e "ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}"
    echo -e "ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}"
    echo -e "Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}"

    rm $OUT1 $OUT2
done

================================================
FILE: test/mezo_sgd/nanogpt/utils.py
================================================
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0

import torch
import time
import argparse
from tqdm import tqdm
import psutil
import os
import pynvml

def get_args():
    args = argparse.ArgumentParser()
    args.add_argument("--zo_method", type=str, default="zo2")
    args.add_argument("--eval", action="store_true")
    args.add_argument("--model_id", type=str, default="gpt2")
args.add_argument("--model_dtype", type=str, default="fp32") args.add_argument("--verbose", action="store_true") args.add_argument("--max_steps", type=int, default=3) args.add_argument("--lr", type=float, default=1e-4) args.add_argument("--weight_decay", type=float, default=1e-1) args.add_argument("--zo_eps", type=float, default=1e-3) args.add_argument("--seed", type=int, default=42) args.add_argument("--batch_size", type=int, default=1) args.add_argument("--offloading_device", type=str, default="cpu") args.add_argument("--working_device", type=str, default="cuda:0") args = args.parse_args() args.model_dtype = dtype_lookup[args.model_dtype] return args dtype_lookup = { "fp64": torch.float64, "fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16 } def model_size(model: torch.nn.Module): total_size = sum(p.numel() for p in model.parameters()) trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad) return {"total": total_size, "trainable": trainable_size} def prepare_data(V, B, T, device='cuda'): data_batch = torch.randint(0, V, (B, T+1)).to(device) input_ids = data_batch[:, :T] labels = data_batch[:, 1:T+1] pos = torch.arange(input_ids.shape[1], dtype=torch.long, device=device).unsqueeze(0) return input_ids, pos, labels # GPU Memory Monitoring pynvml.nvmlInit() def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False): # Check the peak memory usage handle = pynvml.nvmlDeviceGetHandleByIndex(device) # Adjust index if necessary info = pynvml.nvmlDeviceGetMemoryInfo(handle) peak_memory = info.used / 1024**2 if use_tqdm: tqdm.write("Peak GPU Memory after iteration {}: {:.2f} MB".format(iter+1, peak_memory)) else: print(f"Peak GPU Memory after iteration {iter+1}: {peak_memory:.2f} MB") # CPU Memory Monitoring peak_memory_cpu = 0 def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False): global peak_memory_cpu process = psutil.Process(os.getpid()) current_memory = process.memory_info().rss / (1024 ** 2) # Convert to MB if 
current_memory > peak_memory_cpu: peak_memory_cpu = current_memory if use_tqdm: tqdm.write(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB") else: print(f"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB") def reset_peak_cpu_memory_usage(): global peak_memory_cpu peak_memory_cpu = 0 if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs): t1 = time.time() out = fn(*args, **kwargs) t2 = time.time() time_cost = t2-t1 throughtput = total_token_batch_size_per_iter / time_cost if use_tqdm: tqdm.write("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughtput)) else: print("Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s".format(iter+1, time_cost*1e3, throughtput)) ================================================ FILE: tutorial/README.md ================================================ # API of ZO2 Welcome to the ZO2 API documentation! ## Standard Usage ### 1. Quick Start For a straightforward introduction to using ZO2, refer to the Jupyter notebook: [demo.ipynb](./demo.ipynb) ### 2. Huggingface Trainer To see how ZO2 can be integrated with the Huggingface Trainer for efficient model training, check out: [huggingface.ipynb](./huggingface.ipynb) ## 3. Extend ZO2 to Your Own PyTorch Models Learn how to apply ZO2 to your own PyTorch models by following the example of building a nanogpt model: [nanogpt.ipynb](./nanogpt.ipynb). 
================================================ FILE: tutorial/colab.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment Setting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "# Set CUDA_VISIBLE_DEVICES to 0 to make only the first GPU visible\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q condacolab\n", "import condacolab\n", "condacolab.install()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import condacolab\n", "condacolab.check()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "!rm -rf zo2/\n", "!git clone https://github.com/liangyuwang/zo2.git\n", "print(\"Current working directory:\", os.getcwd())\n", "os.chdir('zo2/')\n", "print(\"New working directory:\", os.getcwd())\n", "\n", "!conda env update -n base -f env.yml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using [MeZO Runner](../example/mezo_runner/) on Supported Tasks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "print(\"Current working directory:\", os.getcwd())\n", "os.chdir('./example/mezo_runner/')\n", "print(\"New working directory:\", os.getcwd())\n", "\n", "!MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n", "\n", "os.chdir('../../tutorial/')\n", "print(\"New working directory:\", os.getcwd())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Huggingface Trainer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../\")\n", "\n", "from tqdm.auto import tqdm\n", "import 
torch\n", "from transformers import (\n", " AutoTokenizer, \n", " TrainingArguments,\n", " DataCollatorForLanguageModeling\n", ")\n", "from datasets import load_dataset\n", "from zo2 import (\n", " ZOConfig,\n", " zo_hf_init,\n", ")\n", "from zo2.trainer.hf_transformers.trainer import ZOTrainer\n", "from zo2.trainer.hf_trl.sft_trainer import ZOSFTTrainer\n", "from zo2.utils import seed_everything" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hyperparameter\n", "zo_method = \"zo2\"\n", "eval_mode = False\n", "model_name = \"facebook/opt-2.7b\"\n", "verbose = True\n", "max_steps = 300\n", "learning_rate = 1e-7\n", "weight_decay = 1e-1\n", "zo_eps = 1e-3\n", "seed = 42\n", "offloading_device = \"cpu\"\n", "working_device = \"cuda:0\"\n", "max_train_data = None\n", "max_eval_data = None\n", "use_cache = True\n", "max_new_tokens = 50\n", "temperature = 1.0\n", "seed_everything(seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ZO steps\n", "zo_config = ZOConfig(\n", " method=\"mezo-sgd\", \n", " zo2=zo_method==\"zo2\", \n", " lr=learning_rate,\n", " weight_decay=weight_decay,\n", " eps=zo_eps,\n", " offloading_device=offloading_device,\n", " working_device=working_device,\n", ")\n", "\n", "# Load ZO model\n", "with zo_hf_init(zo_config):\n", " from transformers import OPTForCausalLM\n", " model = OPTForCausalLM.from_pretrained(model_name)\n", " model.zo_init(zo_config)\n", "if zo_method != \"zo2\": \n", " model = model.to(working_device)\n", "print(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare dataset\n", "dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')\n", "\n", "# tokenizing dataset\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "block_size = tokenizer.model_max_length\n", "def tokenize_function(examples):\n", 
" return tokenizer(examples[\"text\"])\n", "tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])\n", "def group_texts(examples):\n", " # Concatenate all texts.\n", " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", " # customize this part to your needs.\n", " total_length = (total_length // block_size) * block_size\n", " # Split by chunks of max_len.\n", " result = {\n", " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", " for k, t in concatenated_examples.items()\n", " }\n", " result[\"labels\"] = result[\"input_ids\"].copy()\n", " return result\n", "lm_datasets = tokenized_datasets.map(\n", " group_texts,\n", " batched=True,\n", " batch_size=1000,\n", " num_proc=4,\n", ")\n", "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# trainer init\n", "training_args = TrainingArguments(\n", " \"test-trainer\", \n", " max_steps=max_steps,\n", " save_strategy=\"no\", \n", " logging_steps=10,\n", ")\n", "\n", "trainer = ZOTrainer(\n", " model,\n", " training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " data_collator=data_collator,\n", " processing_class=tokenizer,\n", ")\n", "\n", "# 'ZOTrainer' provides the capability to register pre-hooks and post-hooks during zo_step\n", "def drop_invalid_data(model, inputs, loss):\n", " # Extract projected_grad, handle both tensor and scalar cases\n", " projected_grad = model.opt.projected_grad\n", " if isinstance(projected_grad, torch.Tensor):\n", " projected_grad_is_nan = torch.isnan(projected_grad).any()\n", " else:\n", " projected_grad_is_nan = projected_grad != projected_grad # Check for NaN in scalars\n", 
" if torch.isnan(loss) or projected_grad_is_nan:\n", " tqdm.write(\"'loss': {} or 'projected_grad': {} is nan. Drop this step.\".format(\n", " loss, model.opt.projected_grad\n", " ))\n", " model.opt.projected_grad = 0 # Reset projected_grad to prevent parameter updates\n", " return model, inputs, loss\n", "trainer.register_zo2_training_step_post_hook(drop_invalid_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# trainer step\n", "trainer.train()" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: tutorial/demo.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tune HF Model with Your Custom Training Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../\")\n", "\n", "from tqdm.auto import tqdm\n", "import torch\n", "from transformers import AutoTokenizer\n", "from zo2 import (\n", " ZOConfig,\n", " zo_hf_init,\n", ")\n", "from zo2.utils import seed_everything" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hyperparameter\n", "zo_method = \"zo2\"\n", "eval_mode = False\n", "model_name = \"facebook/opt-2.7b\"\n", "verbose = True\n", "max_steps = 100\n", "learning_rate = 1e-5\n", "weight_decay = 1e-1\n", "zo_eps = 1e-3\n", "seed = 42\n", "offloading_device = \"cpu\"\n", "working_device = \"cuda:0\"\n", "use_cache = True\n", "max_new_tokens = 50\n", "temperature = 1.0\n", "seed_everything(seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ZO steps\n", "zo_config = ZOConfig(\n", " method=\"mezo-sgd\", \n", " zo2=zo_method==\"zo2\", \n", " lr=learning_rate,\n", " weight_decay=weight_decay,\n", " eps=zo_eps,\n", " 
offloading_device=offloading_device,\n", " working_device=working_device,\n", ")\n", "\n", "# Load ZO model\n", "with zo_hf_init(zo_config):\n", " from transformers import OPTForCausalLM\n", " model = OPTForCausalLM.from_pretrained(model_name)\n", " model.zo_init(zo_config)\n", "if zo_method != \"zo2\": \n", " model = model.to(working_device)\n", "print(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare some data\n", "dataset = \"\"\"\n", " What is ZO2? \n", " ZO2 is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using zeroth-order (ZO) optimization techniques and advanced offloading technologies. \n", " This framework is particularly tailored for setups with limited GPU memory, enabling the fine-tuning of models that were previously unmanageable due to hardware constraints. \n", " As the scale of Large Language Models (LLMs) continues to grow, reaching parameter counts in the hundreds of billions, managing GPU memory resources effectively becomes crucial. \n", " Efficient GPU memory management is crucial not only because it directly influences model performance and training speed, but also because GPU memory is both expensive and limited in quantity. \n", " However, this creates a significant challenge in handling ever-larger models within the physical constraints of current hardware technologies. \n", " CPU offloading has become a crucial technique for overcoming this challenge. \n", " It involves transferring computations and data from the GPU to the CPU, specifically targeting data or parameters that are less frequently accessed. \n", " By offloading these inactive tensors of the neural network, CPU offloading effectively alleviates the memory and computational pressures on GPUs. 
\n", " While CPU offloading has been commonly applied in inference to manage memory-intensive tasks, its application in training, especially fine-tuning, remains less explored. \n", " Recently, some works have tried to introduce CPU offloading into LLM training. \n", " However, they are typically constrained by the capabilities of first-order optimizers such as SGD and Adaptive Moment Estimation (AdamW), and limited GPU memory, restricting large-scale model scalability on single GPU systems. \n", " Using first-order optimizers introduces inefficiencies in CPU offloading: Multiple communication operations during the training of LLMs necessitate offloading the same data twice—once for each pass. \n", " This redundancy not only doubles the communication volume between the CPU and GPU but also introduces significant latency due to repetitive data transfers. \n", " Furthermore, both parameters and activations are required in the backward pass to complete gradient computations. \n", " This means that parameters and activation values must be offloaded during each forward pass and re-uploaded to the GPU for the backward pass, increasing the volume of data transferred, which severely impacts training throughput. \n", " On the other hand, zeroth-order (ZO) methods offer a novel approach to fine-tuning LLMs. \n", " These methods utilize dual forward passes to estimate parameter gradients and subsequently update parameters. \n", " This approach eliminates the traditional reliance on backward passes, thereby streamlining the training process by significantly reducing the number of computational steps required. \n", " Based on these observations, we conjecture that ZO's architecture is particularly well-suited for CPU offloading strategies. \n", " By eliminating backward passes and the need to store activation values, it can significantly reduce GPU memory demands through efficient parameter offloading. 
\n", " However, despite these advantages, ZO training via CPU offloading introduces new challenges, particularly in the realm of CPU-to-GPU communication. \n", " Transferring parameters between the CPU and GPU, which is crucial for maintaining gradient computation and model updates, becomes a critical bottleneck. \n", " Although ZO methods inherently extend computation times because of the dual forward passes, potentially allowing for better overlap between computation and communication, there remain significant inefficiencies. \n", " The necessity to upload parameters to the GPU for upcoming computations introduces a large volume of communications. To tackle the inefficiencies highlighted, we introduce ZO2, a novel framework specifically designed for ZO fine-tuning in LLMs with CPU offloading. \n", " This framework utilizes the unique dual forward pass architecture of ZO methods to optimize interactions between CPU and GPU, significantly enhancing both computational and communication efficiency. \n", " By building a high-performance dynamic scheduler, ZO2 achieves substantial overlaps in communication and computation. \n", " These innovations make it feasible to fine-tune extremely large models, such as the OPT-175B, with over 175 billion parameters, on a single GPU equipped with just 18GB of memory usage—a capability previously unattainable with conventional methods. 
\n", " Additionally, our efficient framework operates without any extra time cost and decreases in accuracy compared to standard ZO methodologies.\"\"\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "data_batch = tokenizer(dataset, add_special_tokens=True, return_tensors='pt').input_ids.to(working_device)\n", "T = min(data_batch.shape[1] - 1, model.config.max_position_embeddings)\n", "print(f\"Fine-tuning model {model_name} with {T} tokens dataset: \\n{dataset}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Training loop\n", "for i in tqdm(range(max_steps)):\n", " model.zo_train()\n", " loss = model(input_ids=data_batch, labels=data_batch)\n", "\n", " # eval\n", " if eval_mode:\n", " if i==0:\n", " tqdm.write(\"Warning: please notice that ZO2 does not optimize the evaluation, so it may be very slow.\")\n", " model.zo_eval()\n", " output = model(input_ids=data_batch, labels=data_batch)\n", " res = \"Iteration {}, train loss: {}, projected grad: {}, eval loss: {}\"\n", " tqdm.write(res.format(i, loss, model.opt.projected_grad, output[\"loss\"]))\n", " else:\n", " res = \"Iteration {}, train loss: {}, projected grad: {}\"\n", " tqdm.write(res.format(i, loss, model.opt.projected_grad))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# inference\n", "print(\"Doing inference...\")\n", "print(\"Warning: please notice that ZO2 does not optimize the inference, so it may be very slow.\")\n", "model.zo_eval()\n", "prompt = \"What is ZO2 and how ZO2 enhance the fine-tuning of large language models?\"\n", "inputs = tokenizer(prompt, return_tensors='pt').to(working_device)\n", "inputs = {\"input_ids\": inputs.input_ids}\n", "for _ in tqdm(range(max_new_tokens)):\n", " outputs = model(**inputs, return_dict=True)\n", " next_token_logits = outputs.logits[:, -1, :]\n", " if temperature == 1.0:\n", " next_token = torch.argmax(next_token_logits, 
dim=-1).unsqueeze(-1)\n", " else:\n", " scaled_logits = next_token_logits / temperature\n", " probs = torch.nn.functional.softmax(scaled_logits, dim=-1)\n", " next_token = torch.multinomial(probs, num_samples=1)\n", " inputs = torch.cat([inputs[\"input_ids\"], next_token], dim=-1)\n", " generated_text = tokenizer.decode(inputs[0])\n", " inputs = {\"input_ids\": inputs}\n", "print(f\"Question: {prompt}\")\n", "print(f\"Response: {generated_text[len(prompt)+4:]}...\")" ] } ], "metadata": { "kernelspec": { "display_name": "mezo", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: tutorial/huggingface.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment Setting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "# Set CUDA_VISIBLE_DEVICES to 0 to make only the first GPU visible\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using [MeZO Runner](../example/mezo_runner/) on Supported Tasks" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "print(\"Current working directory:\", os.getcwd())\n", "os.chdir('../example/mezo_runner/')\n", "print(\"New working directory:\", os.getcwd())\n", "\n", "!MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n", "\n", "os.chdir('../../tutorial/')\n", "print(\"New working directory:\", os.getcwd())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Huggingface Trainer" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../\")\n", "\n", "from tqdm.auto import tqdm\n", "import torch\n", "from transformers import (\n", " AutoTokenizer, \n", " TrainingArguments,\n", " DataCollatorForLanguageModeling\n", ")\n", "from datasets import load_dataset\n", "from zo2 import (\n", " ZOConfig,\n", " zo_hf_init,\n", ")\n", "from zo2.trainer.hf_transformers.trainer import ZOTrainer\n", "from zo2.trainer.hf_trl.sft_trainer import ZOSFTTrainer\n", "from zo2.utils import seed_everything" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hyperparameter\n", "zo_method = \"zo2\"\n", "eval_mode = False\n", "model_name = \"facebook/opt-2.7b\"\n", "verbose = True\n", "max_steps = 300\n", "learning_rate = 1e-7\n", "weight_decay = 1e-1\n", "zo_eps = 1e-3\n", "seed = 42\n", "offloading_device = \"cpu\"\n", "working_device = \"cuda:0\"\n", "max_train_data = None\n", "max_eval_data = None\n", "use_cache = True\n", "max_new_tokens = 50\n", "temperature = 1.0\n", "seed_everything(seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ZO steps\n", "zo_config = ZOConfig(\n", " method=\"mezo-sgd\", \n", " zo2=zo_method==\"zo2\", \n", " lr=learning_rate,\n", " weight_decay=weight_decay,\n", " eps=zo_eps,\n", " offloading_device=offloading_device,\n", " working_device=working_device,\n", ")\n", "\n", "# Load ZO model\n", "with zo_hf_init(zo_config):\n", " from transformers import OPTForCausalLM\n", " model = OPTForCausalLM.from_pretrained(model_name)\n", " model.zo_init(zo_config)\n", "if zo_method != \"zo2\": \n", " model = model.to(working_device)\n", "print(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare dataset\n", "dataset = load_dataset('wikitext', 
'wikitext-2-raw-v1')\n", "\n", "# tokenizing dataset\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "block_size = tokenizer.model_max_length\n", "def tokenize_function(examples):\n", " return tokenizer(examples[\"text\"])\n", "tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])\n", "def group_texts(examples):\n", " # Concatenate all texts.\n", " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", " # customize this part to your needs.\n", " total_length = (total_length // block_size) * block_size\n", " # Split by chunks of max_len.\n", " result = {\n", " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", " for k, t in concatenated_examples.items()\n", " }\n", " result[\"labels\"] = result[\"input_ids\"].copy()\n", " return result\n", "lm_datasets = tokenized_datasets.map(\n", " group_texts,\n", " batched=True,\n", " batch_size=1000,\n", " num_proc=4,\n", ")\n", "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# trainer init\n", "training_args = TrainingArguments(\n", " \"test-trainer\", \n", " max_steps=max_steps,\n", " save_strategy=\"no\", \n", " logging_steps=10,\n", ")\n", "\n", "trainer = ZOTrainer(\n", " model,\n", " training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " data_collator=data_collator,\n", " processing_class=tokenizer,\n", ")\n", "\n", "# 'ZOTrainer' provides the capability to register pre-hooks and post-hooks during zo_step\n", "def drop_invalid_data(model, inputs, loss):\n", " # Extract projected_grad, handle both tensor and scalar cases\n", " projected_grad = model.opt.projected_grad\n", " if 
isinstance(projected_grad, torch.Tensor):\n", " projected_grad_is_nan = torch.isnan(projected_grad).any()\n", " else:\n", " projected_grad_is_nan = projected_grad != projected_grad # Check for NaN in scalars\n", " if torch.isnan(loss) or projected_grad_is_nan:\n", " tqdm.write(\"'loss': {} or 'projected_grad': {} is nan. Drop this step.\".format(\n", " loss, model.opt.projected_grad\n", " ))\n", " model.opt.projected_grad = 0 # Reset projected_grad to prevent parameter updates\n", " return model, inputs, loss\n", "trainer.register_zo2_training_step_post_hook(drop_invalid_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# trainer step\n", "trainer.train()" ] } ], "metadata": { "kernelspec": { "display_name": "mezo", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: tutorial/nanogpt.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment Setting" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../\")\n", "\n", "import math\n", "import inspect\n", "from dataclasses import dataclass\n", "from tqdm.auto import tqdm\n", "\n", "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "\n", "from zo2 import ZOConfig\n", "from zo2.model.base import BaseZOModel\n", "from zo2.optimizer.mezo_sgd.zo import MeZOSGD\n", "from zo2.optimizer.mezo_sgd.zo2 import MeZO2SGD\n", "from zo2.config.mezo_sgd import MeZOSGDConfig\n", "from zo2.utils import seed_everything" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"## Define Your Model \n", "Here we use NanoGPT model, copied from [nanogpt github](https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py), as an example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class GPTConfig:\n", " block_size: int = 1024\n", " vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency\n", " n_layer: int = 12\n", " n_head: int = 12\n", " n_embd: int = 768\n", " dropout: float = 0.0\n", " bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n", "\n", "class LayerNorm(nn.Module):\n", " \"\"\" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False \"\"\"\n", "\n", " def __init__(self, ndim, bias):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.ones(ndim))\n", " self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\n", "\n", " def forward(self, input):\n", " return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\n", "\n", "class CausalSelfAttention(nn.Module):\n", "\n", " def __init__(self, config):\n", " super().__init__()\n", " assert config.n_embd % config.n_head == 0\n", " # key, query, value projections for all heads, but in a batch\n", " self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)\n", " # output projection\n", " self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n", " # regularization\n", " self.attn_dropout = nn.Dropout(config.dropout)\n", " self.resid_dropout = nn.Dropout(config.dropout)\n", " self.n_head = config.n_head\n", " self.n_embd = config.n_embd\n", " self.dropout = config.dropout\n", " # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0\n", " self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')\n", " if not self.flash:\n", " print(\"WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0\")\n", " # causal mask to ensure that attention is only applied to the left in the input sequence\n", " self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n", " .view(1, 1, config.block_size, config.block_size))\n", "\n", " def forward(self, x):\n", " B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n", "\n", " # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n", " q, k, v = self.c_attn(x).split(self.n_embd, dim=2)\n", " k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", " q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", " v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n", "\n", " # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)\n", " if self.flash:\n", " # efficient attention using Flash Attention CUDA kernels\n", " y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)\n", " else:\n", " # manual implementation of attention\n", " att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n", " att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))\n", " att = F.softmax(att, dim=-1)\n", " att = self.attn_dropout(att)\n", " y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", " y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n", "\n", " # output projection\n", " y = self.resid_dropout(self.c_proj(y))\n", " return y\n", "\n", "class MLP(nn.Module):\n", "\n", " def __init__(self, config):\n", " super().__init__()\n", " self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)\n", " self.gelu = nn.GELU()\n", " self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)\n", " 
self.dropout = nn.Dropout(config.dropout)\n", "\n", " def forward(self, x):\n", " x = self.c_fc(x)\n", " x = self.gelu(x)\n", " x = self.c_proj(x)\n", " x = self.dropout(x)\n", " return x\n", "\n", "class Block(nn.Module):\n", "\n", " def __init__(self, config):\n", " super().__init__()\n", " self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\n", " self.attn = CausalSelfAttention(config)\n", " self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\n", " self.mlp = MLP(config)\n", "\n", " def forward(self, x):\n", " x = x + self.attn(self.ln_1(x))\n", " x = x + self.mlp(self.ln_2(x))\n", " return x\n", "\n", "\n", "class GPT(nn.Module):\n", "\n", " def __init__(self, config):\n", " super().__init__()\n", " assert config.vocab_size is not None\n", " assert config.block_size is not None\n", " self.config = config\n", "\n", " self.transformer = nn.ModuleDict(dict(\n", " wte = nn.Embedding(config.vocab_size, config.n_embd),\n", " wpe = nn.Embedding(config.block_size, config.n_embd),\n", " drop = nn.Dropout(config.dropout),\n", " h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n", " ln_f = LayerNorm(config.n_embd, bias=config.bias),\n", " ))\n", " self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n", " # with weight tying when using torch.compile() some warnings get generated:\n", " # \"UserWarning: functional_call was passed multiple values for tied weights.\n", " # This behavior is deprecated and will be an error in future versions\"\n", " # not 100% sure what this is, so far seems to be harmless. 
TODO investigate\n", " # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying\n", "\n", " # init all weights\n", " self.apply(self._init_weights)\n", " # apply special scaled init to the residual projections, per GPT-2 paper\n", " for pn, p in self.named_parameters():\n", " if pn.endswith('c_proj.weight'):\n", " torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\n", "\n", " # report number of parameters\n", " print(\"number of parameters: %.2fM\" % (self.get_num_params()/1e6,))\n", "\n", " def get_num_params(self, non_embedding=True):\n", " \"\"\"\n", " Return the number of parameters in the model.\n", " For non-embedding count (default), the position embeddings get subtracted.\n", " The token embeddings would too, except due to the parameter sharing these\n", " params are actually used as weights in the final layer, so we include them.\n", " \"\"\"\n", " n_params = sum(p.numel() for p in self.parameters())\n", " if non_embedding:\n", " n_params -= self.transformer.wpe.weight.numel()\n", " return n_params\n", "\n", " def _init_weights(self, module):\n", " if isinstance(module, nn.Linear):\n", " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", " if module.bias is not None:\n", " torch.nn.init.zeros_(module.bias)\n", " elif isinstance(module, nn.Embedding):\n", " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", "\n", " def forward(self, idx, pos, targets=None):\n", " # idx is of shape (B, T)\n", " B, T = idx.size()\n", " assert T <= self.config.block_size, f\"Cannot forward sequence of length {T}, block size is only {self.config.block_size}\"\n", " # forward the token and position embeddings\n", " pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)\n", " tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)\n", " x = tok_emb + pos_emb\n", " # forward the blocks of the transformer\n", " for block in self.transformer.h:\n",
" x = block(x)\n", " # forward the final layernorm and the classifier\n", " x = self.transformer.ln_f(x)\n", " logits = self.lm_head(x) # (B, T, vocab_size)\n", " loss = None\n", " if targets is not None:\n", " loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))\n", " return logits, loss\n", "\n", " def crop_block_size(self, block_size):\n", " # model surgery to decrease the block size if necessary\n", " # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)\n", " # but want to use a smaller block size for some smaller, simpler model\n", " assert block_size <= self.config.block_size\n", " self.config.block_size = block_size\n", " self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])\n", " for block in self.transformer.h:\n", " if hasattr(block.attn, 'bias'):\n", " block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]\n", "\n", " @classmethod\n", " def from_pretrained(cls, model_type, override_args=None):\n", " assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\n", " override_args = override_args or {} # default to empty dict\n", " # only dropout can be overridden see more notes below\n", " assert all(k == 'dropout' for k in override_args)\n", " from transformers import GPT2LMHeadModel\n", " print(\"loading weights from pretrained gpt: %s\" % model_type)\n", "\n", " # n_layer, n_head and n_embd are determined from model_type\n", " config_args = {\n", " 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n", " 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n", " 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n", " 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n", " }[model_type]\n", " print(\"forcing vocab_size=50257, block_size=1024, bias=True\")\n", " config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\n", " config_args['block_size'] = 1024 # always 1024 for 
GPT model checkpoints\n", " config_args['bias'] = True # always True for GPT model checkpoints\n", " # we can override the dropout rate, if desired\n", " if 'dropout' in override_args:\n", " print(f\"overriding dropout rate to {override_args['dropout']}\")\n", " config_args['dropout'] = override_args['dropout']\n", " # create a from-scratch initialized minGPT model\n", " config = GPTConfig(**config_args)\n", " model = GPT(config)\n", " sd = model.state_dict()\n", " sd_keys = sd.keys()\n", " sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\n", "\n", " # init a huggingface/transformers model\n", " model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n", " sd_hf = model_hf.state_dict()\n", "\n", " # copy while ensuring all of the parameters are aligned and match in names and shapes\n", " sd_keys_hf = sd_hf.keys()\n", " sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\n", " sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)\n", " transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n", " # basically the openai checkpoints use a \"Conv1D\" module, but we only want to use a vanilla Linear\n", " # this means that we have to transpose these weights when we import them\n", " assert len(sd_keys_hf) == len(sd_keys), f\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\"\n", " for k in sd_keys_hf:\n", " if any(k.endswith(w) for w in transposed):\n", " # special treatment for the Conv1D weights we need to transpose\n", " assert sd_hf[k].shape[::-1] == sd[k].shape\n", " with torch.no_grad():\n", " sd[k].copy_(sd_hf[k].t())\n", " else:\n", " # vanilla copy over the other parameters\n", " assert sd_hf[k].shape == sd[k].shape\n", " with torch.no_grad():\n", " sd[k].copy_(sd_hf[k])\n", "\n", " return model\n", "\n", " def configure_optimizers(self, weight_decay, 
learning_rate, betas, device_type):\n", " # start with all of the candidate parameters\n", " param_dict = {pn: p for pn, p in self.named_parameters()}\n", " # filter out those that do not require grad\n", " param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n", " # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n", " # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n", " decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n", " nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n", " optim_groups = [\n", " {'params': decay_params, 'weight_decay': weight_decay},\n", " {'params': nodecay_params, 'weight_decay': 0.0}\n", " ]\n", " num_decay_params = sum(p.numel() for p in decay_params)\n", " num_nodecay_params = sum(p.numel() for p in nodecay_params)\n", " print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n", " print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n", " # Create AdamW optimizer and use the fused version if it is available\n", " fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n", " use_fused = fused_available and device_type == 'cuda'\n", " extra_args = dict(fused=True) if use_fused else dict()\n", " optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n", " print(f\"using fused AdamW: {use_fused}\")\n", "\n", " return optimizer\n", "\n", " def estimate_mfu(self, fwdbwd_per_iter, dt):\n", " \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n", " # first estimate the number of flops we do per iteration.\n", " # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n", " N = self.get_num_params()\n", " cfg = self.config\n", " L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n", " 
flops_per_token = 6*N + 12*L*H*Q*T\n", " flops_per_fwdbwd = flops_per_token * T\n", " flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n", " flops_achieved = flops_per_iter * (1.0/dt) # per second\n", " flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n", " mfu = flops_achieved / flops_promised\n", " return mfu\n", "\n", " @torch.no_grad()\n", " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n", " \"\"\"\n", " Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n", " the sequence max_new_tokens times, feeding the predictions back into the model each time.\n", " Most likely you'll want to make sure to be in model.eval() mode of operation for this.\n", " \"\"\"\n", " for _ in range(max_new_tokens):\n", " idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\n", " logits, _ = self(idx_cond)\n", " logits = logits[:, -1, :] / temperature\n", " if top_k is not None:\n", " v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n", " logits[logits < v[:, [-1]]] = -float('Inf')\n", " probs = F.softmax(logits, dim=-1)\n", " idx_next = torch.multinomial(probs, num_samples=1)\n", " idx = torch.cat((idx, idx_next), dim=1)\n", "\n", " return idx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Apply ZO to NanoGPT" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a ZO optimizer\n", "class Optimizer(MeZOSGD):\n", "\n", " @torch.inference_mode\n", " def inner_zo_forward(self, idx, pos, targets):\n", " tok_emb = self.model.transformer.wte(idx)\n", " pos_emb = self.model.transformer.wpe(pos)\n", " x = tok_emb + pos_emb\n", " for block in self.model.transformer.h:\n", " x = block(x)\n", " x = self.model.transformer.ln_f(x)\n", " x = self.model.lm_head(x)\n", " loss = F.cross_entropy(\n", " x.reshape(-1, x.size(-1)), \n", " targets.reshape(-1)\n", " )\n", " return loss.detach()\n", "\n", " 
@torch.inference_mode() \n", " def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n", " output = eval_fn(idx, pos, targets)\n", " return output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# fused the ZO optimizer into model\n", "class ZOGPT(GPT, BaseZOModel):\n", " def __init__(self, config: GPTConfig, zo_config: MeZOSGDConfig):\n", " super().__init__(config)\n", " self.opt = Optimizer(model=self, config=zo_config)\n", "\n", " def forward(self, idx, pos, targets=None):\n", " if self.zo_training:\n", " return self.opt.zo_forward(idx, pos, targets)\n", " else:\n", " # for evaluate and inference purpose\n", " return self.opt.zo_eval_forward(super().forward, idx, pos, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Apply ZO2 to NanoGPT" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a ZO2 optimizer\n", "class Optimizer(MeZO2SGD):\n", " \n", " def init_zo2_upload(self):\n", " print(\"Upload head and tail to cuda.\")\n", " self.model.transformer.wte = self.model.transformer.wte.to(self.device)\n", " self.model.transformer.wpe = self.model.transformer.wpe.to(self.device)\n", " self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device)\n", " self.model.lm_head = self.model.lm_head.to(self.device)\n", " \n", " self.num_blocks = len(self.model.transformer.h)\n", " if self.offloading_blocks is not None:\n", " self.offloading_blocks = self.offloading_blocks\n", " else:\n", " self.offloading_blocks = list(range(self.num_blocks))\n", " print(f\"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}\")\n", " for i in range(self.num_blocks):\n", " if i in self.offloading_blocks:\n", " continue\n", " else:\n", " self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device)\n", " print(f\"Upload block {i} to cuda.\")\n", "\n", " @torch.inference_mode() \n", " def 
inner_zo_forward(self, idx, pos, targets):\n", " we1, we2 = self.task_compute_module(self.model.transformer.wte,\n", " inputs1={\"input\": idx},\n", " inputs2={\"input\": idx},\n", " grad=self.projected_grad)\n", " # only sync the compute stream at the first compute task\n", " pe1, pe2 = self.task_compute_module(self.model.transformer.wpe, \n", " {\"input\": pos}, \n", " {\"input\": pos}, \n", " self.projected_grad,\n", " compute_sync=False) \n", " # disable the compute stream sync because we want all the compute tasks overlap with the following upload task\n", " hidden_states1, hidden_states2 = self.task_compute_function(torch.add,\n", " {\"input\": we1, \"other\": pe1},\n", " {\"input\": we2, \"other\": pe2},\n", " compute_sync=False)\n", " if 0 in self.offloading_blocks:\n", " self.model.transformer.h[0] = self.task_upload(\n", " module=self.model.transformer.h[0], \n", " device=self.device)\n", " N = len(self.model.transformer.h)\n", " for i in range(1, N):\n", " # follow the rule that do offload the i-2-th block, compute the i-1-th block, and upload the i-th block in order.\n", " if i != 1:\n", " if i-2 in self.offloading_blocks:\n", " self.model.transformer.h[i-2] = self.task_offload(\n", " module=self.model.transformer.h[i-2], \n", " device=self.offloading_device)\n", " hidden_states1, hidden_states2 = self.task_compute_module(\n", " self.model.transformer.h[i-1], \n", " inputs1={\"x\": hidden_states1}, \n", " inputs2={\"x\": hidden_states2}, \n", " grad=self.projected_grad)\n", " if i in self.offloading_blocks:\n", " self.model.transformer.h[i] = self.task_upload(\n", " module=self.model.transformer.h[i], \n", " device=self.device)\n", " if N-2 in self.offloading_blocks:\n", " self.model.transformer.h[N-2] = self.task_offload(\n", " self.model.transformer.h[N-2], device=self.offloading_device)\n", " hidden_states1, hidden_states2 = self.task_compute_module(\n", " self.model.transformer.h[N-1], \n", " inputs1={\"x\": hidden_states1}, \n", " inputs2={\"x\": 
hidden_states2}, \n", " grad=self.projected_grad\n", " )\n", " if N-1 in self.offloading_blocks:\n", " self.model.transformer.h[N-1] = self.task_offload(\n", " self.model.transformer.h[N-1], device=self.offloading_device)\n", " logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f,\n", " inputs1={\"input\": hidden_states1}, \n", " inputs2={\"input\": hidden_states2}, \n", " grad=self.projected_grad,\n", " weight_decay=0.) \n", " # 'task_compute_module' will remove the first name 'ln_f', so we need to disable weight_decay manually.\n", " logits1, logits2 = self.task_compute_module(self.model.lm_head,\n", " inputs1={\"input\": logits1}, \n", " inputs2={\"input\": logits2}, \n", " grad=self.projected_grad)\n", " loss1, loss2 = self.task_compute_function(F.cross_entropy,\n", " {\"input\": logits1.reshape(-1, logits1.size(-1)), \n", " \"target\": targets.reshape(-1)},\n", " {\"input\": logits2.reshape(-1, logits2.size(-1)), \n", " \"target\": targets.reshape(-1)})\n", " return loss1, loss2\n", " \n", " @torch.inference_mode() \n", " def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n", " handles = self.add_zo2_eval_comm_hooks(self.model.transformer.h)\n", " # You can add zo2_eval_comm_hooks to all transformer blocks,\n", " # but it may be slower.\n", " output = eval_fn(idx, pos, targets)\n", " self.clear_zo2_eval_comm_hooks(handles)\n", " return output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# fuse the ZO2 optimizer into the model\n", "class ZO2GPT(GPT, BaseZOModel):\n", " def __init__(self, config: GPTConfig, zo_config: MeZOSGDConfig):\n", " super().__init__(config)\n", " self.opt = Optimizer(model=self, config=zo_config)\n", "\n", " def forward(self, idx, pos, targets=None):\n", " if self.zo_training:\n", " return self.opt.zo_forward(idx, pos, targets)\n", " else:\n", " # for evaluation and inference purposes\n", " return self.opt.zo_eval_forward(super().forward, idx, pos, targets)" ] }, {
"cell_type": "markdown", "metadata": {}, "source": [ "## Train the Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hyperparameter\n", "zo_method = \"zo2\"\n", "eval_mode = False\n", "model_name = \"gpt2_xl\"\n", "verbose = True\n", "max_steps = 100\n", "learning_rate = 1e-4\n", "batch_size = 1\n", "weight_decay = 1e-1\n", "zo_eps = 1e-3\n", "seed = 42\n", "offloading_device = \"cpu\"\n", "working_device = \"cuda:0\"\n", "use_cache = True\n", "max_new_tokens = 50\n", "temperature = 1.0\n", "seed_everything(seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ZO steps\n", "zo_config = ZOConfig(\n", " method=\"mezo-sgd\", \n", " zo2=zo_method==\"zo2\", \n", " lr=learning_rate,\n", " weight_decay=weight_decay,\n", " eps=zo_eps,\n", " offloading_device=offloading_device,\n", " working_device=working_device,\n", ")\n", "\n", "# Load ZO model\n", "class GPTConfigs:\n", " gpt2: GPTConfig = GPTConfig(n_layer=12, n_head=12, n_embd=768)\n", " gpt2_medium: GPTConfig = GPTConfig(n_layer=24, n_head=16, n_embd=1024)\n", " gpt2_large: GPTConfig = GPTConfig(n_layer=36, n_head=20, n_embd=1280)\n", " gpt2_xl: GPTConfig = GPTConfig(n_layer=48, n_head=25, n_embd=1600)\n", "cfgs = GPTConfigs()\n", "model_cfg = getattr(cfgs, model_name)\n", "MODEL_CLASS = ZO2GPT if zo_method==\"zo2\" else ZOGPT\n", "model = MODEL_CLASS(config=model_cfg, zo_config=zo_config)\n", "if zo_method != \"zo2\": \n", " model = model.to(working_device)\n", "print(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# prepare some data (random generated)\n", "B, V, T = batch_size, model_cfg.vocab_size, model_cfg.block_size\n", "data_batch = torch.randint(0, V, (B, T+1)).to(working_device)\n", "input_ids = data_batch[:, :T] # shift data and labels\n", "labels = data_batch[:, 1:T+1]\n", "pos = 
torch.arange(input_ids.shape[1], dtype=torch.long, device=working_device).unsqueeze(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# training loop\n", "for i in tqdm(range(max_steps)):\n", " model.zo_train()\n", " loss = model(input_ids, pos, labels)\n", " res = \"Iteration {}, loss: {}, projected grad: {}\"\n", " tqdm.write(res.format(i, loss, model.opt.projected_grad))" ] } ], "metadata": { "kernelspec": { "display_name": "mezo", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: zo2/README.md ================================================ # Core code of ZO2 ## Features 1. Fuse model dual-forward and optimizer step into model forward code. For example, ```python # first-order one training step: model.train() loss = model(input, label) # forward loss.backward() # backward optimizer.step() # update parameters, optimizer states # zo2 one training step: model.zo_train() # Enable zo training loss = model(input, label) # fuse dual-forward, parameters and optimizer states updates ``` ## Code Logic 1. Fuse model dual-forward and optimizer step into model forward code. ## In progress... 
================================================ FILE: zo2/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 # configs from .config import ZOConfig # model from .model.nanogpt.mezo_sgd import get_nanogpt_mezo_sgd from .model.huggingface.zo_init import zo_hf_init from .model.huggingface.opt import ( get_opt_for_causalLM, get_opt_for_sequence_classification, get_opt_for_question_answering ) ================================================ FILE: zo2/config/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 from .mezo_sgd import MeZOSGDConfig def ZOConfig(method: str = "mezo-sgd", **kwargs): match method: case "mezo-sgd": return MeZOSGDConfig(**kwargs) # case "another-method": # return AnotherConfig(**kwargs) case _: raise ValueError(f"Unsupported method {method}") ================================================ FILE: zo2/config/mezo_sgd.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import torch from dataclasses import dataclass @dataclass class MeZOSGDConfig: # zo method zo_method: str = "mezo-sgd" # zo method name, every zo config must include this attribute # zo config lr: float = 1e-3 weight_decay: float = 1e-1 eps: float = 1e-3 max_zo_random_seed = 1000000000 # zo2 config zo2: bool = True # use offloading or not offloading_blocks: list = None # specify offloading blocks or not offloading_device: str = 'cpu' # offload device, can be CPU or a path (for disk offloading, but currently unavailable) working_device: str = 'cuda' # compute device, can be any CUDA device overlap: bool = True # use scheduler to overlap or not compute_module_optimize_method: str = '' # possible values are: ['', 'torch.compile'] compute_function_optimize_method: str = '' # possible values are: ['', 'torch.jit.script'] 
communicate_optimize_method: str = '' # possible values are: ['', 'bucket'] amp: bool = False # use amp or not amp_precision: torch.dtype = torch.bfloat16 # amp autocast precision, possible values are: [torch.bfloat16, torch.float32], valid when using amp precision_on_offloading_device: torch.dtype = torch.float16 # precision on offloading device, valid when using amp precision_on_working_device: torch.dtype = torch.float32 # precision on working device, valid when using amp amp_compress_method: str = 'naive' # currently only support naive amp compress, valid when using amp # debug debug_mode: bool = False # set 'True' to disable random noise ================================================ FILE: zo2/model/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 ================================================ FILE: zo2/model/base.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import torch class BaseZOModel(torch.nn.Module): def __init__(self): super().__init__() self.zo_training = True self.zo_train_loss_fn_pre_hooks = [] self.zo_train_loss_fn_post_hooks = [] self.zo_eval_loss_fn_pre_hooks = [] self.zo_eval_loss_fn_post_hooks = [] self.zo_custom_train_loss_fn = None self.zo_custom_eval_loss_fn = None def zo_train(self): """ Zeroth-order training """ self.zo_training = True self.eval() def zo_eval(self): """ Zeroth-order evaluation """ self.zo_training = False self.eval() def register_zo_train_loss_fn_pre_hook(self, hook_fn): self.zo_train_loss_fn_pre_hooks.append(hook_fn) def register_zo_train_loss_fn_post_hook(self, hook_fn): self.zo_train_loss_fn_post_hooks.append(hook_fn) def register_zo_eval_loss_fn_pre_hook(self, hook_fn): self.zo_eval_loss_fn_pre_hooks.append(hook_fn) def register_zo_eval_loss_fn_post_hook(self, hook_fn): self.zo_eval_loss_fn_post_hooks.append(hook_fn) def 
register_custom_opt(self, custom_opt_obj): if hasattr(self, "opt"): self.opt = custom_opt_obj for module in self.children(): if hasattr(module, "opt"): module.opt = custom_opt_obj ================================================ FILE: zo2/model/huggingface/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 ================================================ FILE: zo2/model/huggingface/gpt/mezo_sgd/zo.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 ================================================ FILE: zo2/model/huggingface/gpt/mezo_sgd/zo2.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 ================================================ FILE: zo2/model/huggingface/llama/mezo_sgd/zo.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import torch import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.llama import modeling_llama ================================================ FILE: zo2/model/huggingface/llama/mezo_sgd/zo2.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import torch import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.llama import modeling_llama ================================================ FILE: zo2/model/huggingface/opt/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 from . 
def get_opt_for_causalLM(zo_config):
    """Return the OPT causal-LM class selected by ``zo_config.zo_method``.

    Raises ``KeyError`` when the method has no registered selector.
    """
    selectors = {"mezo-sgd": mezo_sgd.get_opt_for_causalLM_mezo_sgd}
    select_model_cls = selectors[zo_config.zo_method]
    return select_model_cls(zo_config)


def get_opt_for_sequence_classification(zo_config):
    """Return the OPT sequence-classification class selected by ``zo_config.zo_method``.

    Raises ``KeyError`` when the method has no registered selector.
    """
    selectors = {"mezo-sgd": mezo_sgd.get_opt_for_sequence_classification_mezo_sgd}
    select_model_cls = selectors[zo_config.zo_method]
    return select_model_cls(zo_config)


def get_opt_for_question_answering(zo_config):
    """Return the OPT question-answering class selected by ``zo_config.zo_method``.

    Raises ``KeyError`` when the method has no registered selector.
    """
    selectors = {"mezo-sgd": mezo_sgd.get_opt_for_question_answering_mezo_sgd}
    select_model_cls = selectors[zo_config.zo_method]
    return select_model_cls(zo_config)
def get_start_logits_and_end_logits(logits):
    """Split a QA head output whose last dimension is 2 into start/end logits.

    Returns two contiguous tensors with the last dimension squeezed away.
    """
    start_logits, end_logits = logits.split(1, dim=-1)
    return start_logits.squeeze(-1).contiguous(), end_logits.squeeze(-1).contiguous()


def get_qa_loss(loss_fct, start_logits, start_positions, end_logits, end_positions):
    """Average ``loss_fct`` applied to the start and to the end predictions."""
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2


def init_all_hidden_states(output_hidden_states):
    """Return an empty tuple accumulator when hidden states are requested, else None."""
    return () if output_hidden_states else None


def init_all_self_attns(output_attentions):
    """Return an empty tuple accumulator when attentions are requested, else None."""
    return () if output_attentions else None


def init_next_decoder_cache(use_cache):
    """Return an empty tuple accumulator when caching is enabled, else None."""
    return () if use_cache else None


def update_next_decoder_cache(use_cache, next_decoder_cache, layer_outputs, output_attentions):
    """Append the layer's cache entry to ``next_decoder_cache`` when caching is enabled.

    The cache entry sits at index 2 of ``layer_outputs`` when attentions are
    also returned, otherwise at index 1.
    """
    if use_cache:
        cache_index = 2 if output_attentions else 1
        next_decoder_cache += (layer_outputs[cache_index],)
    return next_decoder_cache


def update_all_self_attns(output_attentions, all_self_attns, layer_outputs):
    """Append ``layer_outputs[1]`` to ``all_self_attns`` when attentions are requested."""
    if output_attentions:
        all_self_attns += (layer_outputs[1],)
    return all_self_attns


def update_all_hidden_states(output_hidden_states, all_hidden_states, hidden_states):
    """Append ``hidden_states`` to ``all_hidden_states`` when hidden states are requested."""
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    return all_hidden_states


def get_past_key_value(past_key_values, idx):
    """Return ``past_key_values[idx]``, or None when no cache was supplied."""
    return past_key_values[idx] if past_key_values is not None else None


def get_opt_sequence_classification_pooled_logits(self, logits, input_ids, inputs_embeds):
    """Select one logit vector per sequence for classification.

    ``self`` is the model; only ``self.config.pad_token_id`` is read. With
    ``input_ids`` and a pad token id, the position is the number of non-pad
    tokens minus one; otherwise the last position (-1) is used.
    """
    # Only the batch size is needed; the sequence length was an unused local.
    source = input_ids if input_ids is not None else inputs_embeds
    batch_size = source.shape[:2][0]
    if self.config.pad_token_id is None or input_ids is None:
        sequence_lengths = -1
    else:
        non_pad = torch.ne(input_ids, self.config.pad_token_id).sum(-1)
        sequence_lengths = (non_pad - 1).to(logits.device)
    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]


def get_opt_sequence_classification_loss(self, loss, pooled_logits, labels):
    """Compute the classification/regression loss for ``pooled_logits``.

    When ``self.config.problem_type`` is None it is inferred from
    ``self.num_labels`` and the label dtype and written back to the config.
    ``loss`` is returned unchanged if the problem type matches no known case.
    """
    if self.config.problem_type is None:
        if self.num_labels == 1:
            self.config.problem_type = "regression"
        elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
            self.config.problem_type = "single_label_classification"
        else:
            self.config.problem_type = "multi_label_classification"

    problem_type = self.config.problem_type
    if problem_type == "regression":
        loss_fct = torch.nn.MSELoss()
        if self.num_labels == 1:
            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(pooled_logits, labels)
    elif problem_type == "single_label_classification":
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
    elif problem_type == "multi_label_classification":
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(pooled_logits, labels)
    return loss


def get_opt_question_answering_start_end_logits(logits):
    """Split QA logits into start/end logits (same as ``get_start_logits_and_end_logits``)."""
    # Delegate instead of keeping a second copy of the identical logic.
    return get_start_logits_and_end_logits(logits)


def get_opt_question_answering_loss(total_loss, start_logits, start_positions, end_logits, end_positions):
    """Compute the span-extraction loss for question answering.

    ``total_loss`` is accepted for interface compatibility; its incoming value
    is not used. Positions outside the sequence are clamped to
    ``start_logits.size(1)`` and ignored by the cross-entropy loss.
    """
    # If we are on multi-GPU, split add a dimension
    if len(start_positions.size()) > 1:
        start_positions = start_positions.squeeze(-1)
    if len(end_positions.size()) > 1:
        end_positions = end_positions.squeeze(-1)
    # sometimes the start/end positions are outside our model inputs, we ignore these terms
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
    return get_qa_loss(loss_fct, start_logits, start_positions, end_logits, end_positions)
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.opt import modeling_opt
from transformers.models.opt.modeling_opt import (
    OPTConfig,
    OPTPreTrainedModel,
    OPTLearnedPositionalEmbedding,
    OPTDecoderLayer,
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    QuestionAnsweringModelOutput,
    add_start_docstrings_to_model_forward,
    add_code_sample_docstrings,
    replace_return_docstrings,
    OPT_INPUTS_DOCSTRING,
    _CHECKPOINT_FOR_DOC,
    _CONFIG_FOR_DOC,
    _EXPECTED_OUTPUT_SHAPE,
    _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
    _SEQ_CLASS_EXPECTED_OUTPUT,
    _SEQ_CLASS_EXPECTED_LOSS,
)
from transformers.utils import logging

import random
from typing import List, Optional, Tuple, Union

from ....base import BaseZOModel
from .....optimizer.mezo_sgd.zo import MeZOSGD
from .....config.mezo_sgd import MeZOSGDConfig

logger = logging.get_logger(__name__)


class OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]

    Only ``__init__`` is overridden here; every other method comes from the
    parent classes.

    Args:
        config: OPTConfig
    """

    def __init__(self, config: OPTConfig):
        """
        Build the decoder submodules.

        !!! Module registration must follow the execution order: the
        submodules below are assigned in the order embeddings -> input
        projection -> layers -> final layer norm -> output projection.
        Do not reorder these assignments.
        """
        # Calls OPTPreTrainedModel.__init__ directly instead of super().__init__();
        # presumably to skip modeling_opt.OPTDecoder.__init__ so that this method
        # alone decides the registration order -- TODO confirm.
        OPTPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        # Input projection exists only when the embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        # Output projection mirrors the input projection (hidden size -> embedding width).
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()


class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel):
    """OPT base model whose ``decoder`` is the `OPTDecoder` defined in this module."""

    def __init__(self, config: OPTConfig):
        # Same direct-base-init pattern as OPTDecoder.__init__ in this module.
        OPTPreTrainedModel.__init__(self, config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()
class OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, BaseZOModel):
    """OPT causal language model whose forward pass is routed through a ZO optimizer.

    `zo_init` must be called before `forward`, because `forward` delegates to
    ``self.opt``.
    """

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        BaseZOModel.__init__(self)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        vocab_projection = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.lm_head = vocab_projection

        # Initialize weights and apply final processing
        self.post_init()

    def zo_init(self, zo_config):
        """Create the ZO optimizer for this model and store it as ``self.opt``."""
        self.opt = OptimizerOPTForCausalLM(model=self, config=zo_config)

    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Dispatch to ``self.opt.zo_forward`` while `zo_training` is set, otherwise to
        ``self.opt.zo_eval_forward`` with the parent class forward as the evaluation function.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask with values in `[0, 1]`: 1 for tokens that are **not masked**, 0 for **masked** tokens.
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask with values in `[0, 1]`: 1 keeps an attention head, 0 nullifies it.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed key/value states that can be used to speed up sequential decoding.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Embedded representation passed directly instead of `input_ids`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for the language modeling loss. Tokens with index `-100` are ignored.
            use_cache (`bool`, *optional*):
                If `True`, `past_key_values` key value states are returned.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            position_ids (`torch.LongTensor`, *optional*):
                Forwarded unchanged to the optimizer forward.
            cache_position (`torch.Tensor`, *optional*):
                Forwarded unchanged to the optimizer forward.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        # Positional order matches the optimizer's inner forward signatures.
        forward_args = (
            input_ids,
            attention_mask,
            head_mask,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            position_ids,
            cache_position,
        )
        if not self.zo_training:
            return self.opt.zo_eval_forward(super().forward, *forward_args, **kwargs)
        return self.opt.zo_forward(*forward_args, **kwargs)
class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassification, OPTPreTrainedModel, BaseZOModel):
    """OPT sequence classifier whose forward pass is routed through a ZO optimizer.

    `zo_init` must be called before `forward`, because `forward` delegates to
    ``self.opt``.
    """

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        BaseZOModel.__init__(self)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)

        classifier_head = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
        self.score = classifier_head

        # Initialize weights and apply final processing
        self.post_init()

    def zo_init(self, zo_config):
        """Create the ZO optimizer for this model and store it as ``self.opt``."""
        self.opt = OptimizerOPTForSequenceClassification(model=self, config=zo_config)

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Positional order matches the optimizer's inner forward signatures.
        forward_args = (
            input_ids,
            attention_mask,
            head_mask,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        if not self.zo_training:
            return self.opt.zo_eval_forward(super().forward, *forward_args)
        return self.opt.zo_forward(*forward_args)
class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTPreTrainedModel, BaseZOModel):
    """OPT extractive-QA model whose forward pass is routed through a ZO optimizer.

    `zo_init` must be called before `forward`, because `forward` delegates to
    ``self.opt``.
    """

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        BaseZOModel.__init__(self)
        self.model = OPTModel(config)

        # Two outputs per token: one start logit and one end logit.
        span_head = nn.Linear(config.word_embed_proj_dim, 2)
        self.qa_outputs = span_head

        # Initialize weights and apply final processing
        self.post_init()

    def zo_init(self, zo_config):
        """Create the ZO optimizer for this model and store it as ``self.opt``."""
        self.opt = OptimizerOPTForQuestionAnswering(model=self, config=zo_config)

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
        >>> import torch

        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> answer_offset = len(tokenizer(question)[0])

        >>> predict_answer_tokens = inputs.input_ids[
        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
        ... ]
        >>> predicted = tokenizer.decode(predict_answer_tokens)
        >>> predicted
        ' a nice puppet'
        ```"""
        # Positional order matches the optimizer's inner forward signatures.
        forward_args = (
            input_ids,
            attention_mask,
            head_mask,
            past_key_values,
            inputs_embeds,
            start_positions,
            end_positions,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        if not self.zo_training:
            return self.opt.zo_eval_forward(super().forward, *forward_args)
        return self.opt.zo_forward(*forward_args)
class OptimizerOPTForCausalLM(MeZOSGD):
    """MeZO-SGD optimizer bound to an `OPTForCausalLM`; supplies its ZO forward passes."""

    @torch.inference_mode()
    def inner_zo_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Run one forward pass of the wrapped model and return only the detached loss.

        This is the original forward code with 'self' replaced by 'self.model'.
        Training pre-hooks are called as ``hook(model, input_ids, logits, labels)`` and
        must return ``(input_ids, logits, labels)``; post-hooks are called as
        ``hook(model, loss, input_ids, logits, labels)`` and must return
        ``(loss, input_ids, logits, labels)``. ``**kwargs`` is forwarded to the
        custom training loss only.

        Raises:
            ValueError: if no loss is available (no custom loss, no ``labels``, and
                no post-hook produced one).
        """
        model = self.model
        output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = model.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            position_ids=position_ids,
            cache_position=cache_position,
        )

        logits = model.lm_head(outputs[0]).contiguous()

        for pre_hook_fn in model.zo_train_loss_fn_pre_hooks:
            input_ids, logits, labels = pre_hook_fn(model, input_ids, logits, labels)

        loss = None
        if model.zo_custom_train_loss_fn:
            loss = model.zo_custom_train_loss_fn(model, input_ids, logits, labels, **kwargs)
        elif labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))

        for post_hook_fn in model.zo_train_loss_fn_post_hooks:
            loss, input_ids, logits, labels = post_hook_fn(model, loss, input_ids, logits, labels)

        # Only the loss is returned. Fail with a clear message instead of the
        # opaque "'NoneType' object has no attribute 'detach'".
        if loss is None:
            raise ValueError(
                "ZO training requires a loss: pass `labels` or set `zo_custom_train_loss_fn`."
            )
        return loss.detach()

    @torch.inference_mode()
    def inner_zo_eval_forward(
        self,
        eval_fn,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Evaluate through ``eval_fn`` (the regular model forward) and apply eval hooks.

        With a custom eval loss, ``eval_fn`` is called without labels and the loss is
        computed by ``zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)``.
        Otherwise ``eval_fn`` receives the labels and computes the loss itself.

        Eval pre-hooks run before the forward pass, so they receive ``logits=None``.
        Post-hooks receive the logits extracted from the output. As in the original
        code, the tuple-returning custom-loss path returns before post-hooks run.
        """
        model = self.model
        # `return_dict=None` means "use the config default", exactly as the model resolves it.
        use_return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        # No forward pass has happened yet; previously `logits` was read while unbound.
        logits = None
        for pre_hook_fn in model.zo_eval_loss_fn_pre_hooks:
            input_ids, logits, labels = pre_hook_fn(model, input_ids, logits, labels)

        if model.zo_custom_eval_loss_fn:
            # Labels are withheld so that eval_fn does not compute its own loss.
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, None, use_cache,
                             output_attentions, output_hidden_states, use_return_dict,
                             position_ids, cache_position)
            if not use_return_dict:
                logits = output[0]
                loss = model.zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)
                # Keep every remaining element (was `output[1]`, a single element).
                output = (logits,) + tuple(output[1:])
                return (loss,) + output if loss is not None else output
            logits = output["logits"]
            loss = model.zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)
            output = CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=output["past_key_values"],
                hidden_states=output["hidden_states"],
                attentions=output["attentions"],
            )
        else:
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, labels, use_cache,
                             output_attentions, output_hidden_states, use_return_dict,
                             position_ids, cache_position)

        if model.zo_eval_loss_fn_post_hooks:
            if not model.zo_custom_eval_loss_fn:
                # Recover the logits for the hooks; in tuple form the loss (present
                # only when labels were given) precedes them.
                if use_return_dict:
                    logits = output["logits"]
                else:
                    logits = output[1] if labels is not None else output[0]
            for post_hook_fn in model.zo_eval_loss_fn_post_hooks:
                output, input_ids, logits, labels = post_hook_fn(model, output, input_ids, logits, labels)

        return output
class OptimizerOPTForSequenceClassification(MeZOSGD):
    """MeZO-SGD optimizer bound to an `OPTForSequenceClassification`; supplies its ZO forward passes."""

    @torch.inference_mode()
    def inner_zo_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Optional[torch.Tensor]:
        """
        Run one forward pass of the wrapped model and return only the detached loss.

        This is the original forward code with 'self' replaced by 'self.model'.
        Hooks and the custom training loss receive the per-token ``logits``, while the
        built-in loss uses the pooled logits computed before the pre-hooks run.
        Returns None when the model is not in ZO training mode.

        Raises:
            ValueError: in ZO training mode, if no loss is available.
        """
        model = self.model
        return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        transformer_outputs = model.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = model.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[:2][0]
        else:
            batch_size = inputs_embeds.shape[:2][0]

        if model.config.pad_token_id is None:
            sequence_lengths = -1
        elif input_ids is not None:
            # Index of the last non-pad token: count of non-pad tokens minus one.
            sequence_lengths = (torch.ne(input_ids, model.config.pad_token_id).sum(-1) - 1).to(logits.device)
        else:
            sequence_lengths = -1
            logger.warning(
                f"{self.model.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        # NOTE(review): pre-hooks may replace `logits`, but `pooled_logits` was
        # already computed above and is not refreshed -- confirm this is intended.
        for pre_hook_fn in model.zo_train_loss_fn_pre_hooks:
            input_ids, logits, labels = pre_hook_fn(model, input_ids, logits, labels)

        loss = None
        if model.zo_custom_train_loss_fn:
            loss = model.zo_custom_train_loss_fn(model, input_ids, logits, labels, **kwargs)
        elif labels is not None:
            if model.config.problem_type is None:
                if model.num_labels == 1:
                    model.config.problem_type = "regression"
                elif model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    model.config.problem_type = "single_label_classification"
                else:
                    model.config.problem_type = "multi_label_classification"

            if model.config.problem_type == "regression":
                loss_fct = MSELoss()
                if model.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif model.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, model.num_labels), labels.view(-1))
            elif model.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        for post_hook_fn in model.zo_train_loss_fn_post_hooks:
            loss, input_ids, logits, labels = post_hook_fn(model, loss, input_ids, logits, labels)

        # Only the loss is returned, and only in ZO training mode.
        if not model.zo_training:
            return None
        if loss is None:
            raise ValueError(
                "ZO training requires a loss: pass `labels` or set `zo_custom_train_loss_fn`."
            )
        return loss.detach()

    @torch.inference_mode()
    def inner_zo_eval_forward(
        self,
        eval_fn,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """
        Evaluate through ``eval_fn`` (the regular model forward) and apply eval hooks.

        With a custom eval loss, ``eval_fn`` is called without labels and the loss is
        computed by ``zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)``
        on the logits that ``eval_fn`` returned. Otherwise ``eval_fn`` receives the
        labels and computes the loss itself.

        Eval pre-hooks run before the forward pass, so they receive ``logits=None``.
        Post-hooks receive the logits extracted from the output. As in the original
        code, the tuple-returning custom-loss path returns before post-hooks run.
        """
        model = self.model
        # `return_dict=None` means "use the config default", exactly as the model resolves it.
        use_return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        # No forward pass has happened yet; previously `logits` was read while unbound.
        logits = None
        for pre_hook_fn in model.zo_eval_loss_fn_pre_hooks:
            input_ids, logits, labels = pre_hook_fn(model, input_ids, logits, labels)

        if model.zo_custom_eval_loss_fn:
            # Labels are withheld so that eval_fn does not compute its own loss.
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, None, use_cache,
                             output_attentions, output_hidden_states, use_return_dict)
            if not use_return_dict:
                logits = output[0]
                loss = model.zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)
                # Keep every remaining element (was `output[1]`, a single element).
                output = (logits,) + tuple(output[1:])
                return (loss,) + output if loss is not None else output
            logits = output["logits"]
            loss = model.zo_custom_eval_loss_fn(model, input_ids, logits, labels, **kwargs)
            # Use the sequence-classification output type (was CausalLMOutputWithPast).
            output = SequenceClassifierOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=output["past_key_values"],
                hidden_states=output["hidden_states"],
                attentions=output["attentions"],
            )
        else:
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, labels, use_cache,
                             output_attentions, output_hidden_states, use_return_dict)

        if model.zo_eval_loss_fn_post_hooks:
            if not model.zo_custom_eval_loss_fn:
                # Recover the logits for the hooks; in tuple form the loss (present
                # only when labels were given) precedes them.
                if use_return_dict:
                    logits = output["logits"]
                else:
                    logits = output[1] if labels is not None else output[0]
            for post_hook_fn in model.zo_eval_loss_fn_post_hooks:
                output, input_ids, logits, labels = post_hook_fn(model, output, input_ids, logits, labels)

        return output
class OptimizerOPTForQuestionAnswering(MeZOSGD):
    """MeZO-SGD optimizer bound to an `OPTForQuestionAnswering`; supplies its ZO forward passes."""

    @torch.inference_mode()
    def inner_zo_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Optional[torch.Tensor]:
        """
        Run one forward pass of the wrapped model and return only the detached loss.

        This is the original forward code with 'self' replaced by 'self.model'.
        Pre-hooks are called as
        ``hook(model, input_ids, start_logits, start_positions, end_logits, end_positions)``
        and return the last five values; post-hooks additionally take and return the
        loss as the first value after the model. Returns None when the model is not in
        ZO training mode.

        Raises:
            ValueError: in ZO training mode, if no loss is available.
        """
        model = self.model
        return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        transformer_outputs = model.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        logits = model.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        for pre_hook_fn in model.zo_train_loss_fn_pre_hooks:
            input_ids, start_logits, start_positions, end_logits, end_positions = \
                pre_hook_fn(model, input_ids, start_logits, start_positions, end_logits, end_positions)

        # A single `loss` variable feeds the custom path, the built-in path, the
        # post-hooks and the return value. The original split it across `loss` and
        # `total_loss`, so the custom loss was never returned and post-hooks hit an
        # unbound `loss` on the built-in path.
        loss = None
        if model.zo_custom_train_loss_fn:
            loss = model.zo_custom_train_loss_fn(
                model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)
        elif start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        for post_hook_fn in model.zo_train_loss_fn_post_hooks:
            loss, input_ids, start_logits, start_positions, end_logits, end_positions = \
                post_hook_fn(model, loss, input_ids, start_logits, start_positions, end_logits, end_positions)

        # Only the loss is returned, and only in ZO training mode.
        if not model.zo_training:
            return None
        if loss is None:
            raise ValueError(
                "ZO training requires a loss: pass `start_positions` and `end_positions` "
                "or set `zo_custom_train_loss_fn`."
            )
        return loss.detach()

    @torch.inference_mode()
    def inner_zo_eval_forward(
        self,
        eval_fn,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        """
        Evaluate through ``eval_fn`` (the regular model forward) and apply eval hooks.

        With a custom eval loss, ``eval_fn`` is called without positions and the loss is
        computed by ``zo_custom_eval_loss_fn`` from the returned start/end logits.
        Otherwise ``eval_fn`` receives the positions and computes the loss itself.

        Eval pre-hooks run before the forward pass, so they receive
        ``start_logits=None`` and ``end_logits=None``. Post-hooks receive the logits
        extracted from the output. As in the original code, the tuple-returning
        custom-loss path returns before post-hooks run.
        """
        model = self.model
        # `return_dict=None` means "use the config default", exactly as the model resolves it.
        use_return_dict = return_dict if return_dict is not None else model.config.use_return_dict

        # No forward pass has happened yet; previously both names were read while unbound.
        start_logits = end_logits = None
        for pre_hook_fn in model.zo_eval_loss_fn_pre_hooks:
            input_ids, start_logits, start_positions, end_logits, end_positions = \
                pre_hook_fn(model, input_ids, start_logits, start_positions, end_logits, end_positions)

        if model.zo_custom_eval_loss_fn:
            # Positions are withheld so that eval_fn does not compute its own loss.
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, None, None, use_cache,
                             output_attentions, output_hidden_states, use_return_dict)
            if not use_return_dict:
                start_logits, end_logits = output[0], output[1]
                loss = model.zo_custom_eval_loss_fn(
                    model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)
                output = (start_logits, end_logits) + tuple(output[2:])
                return (loss,) + output if loss is not None else output
            start_logits = output["start_logits"]
            end_logits = output["end_logits"]
            loss = model.zo_custom_eval_loss_fn(
                model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)
            output = QuestionAnsweringModelOutput(
                loss=loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=output["hidden_states"],
                attentions=output["attentions"],
            )
        else:
            output = eval_fn(input_ids, attention_mask, head_mask,
                             past_key_values, inputs_embeds, start_positions, end_positions, use_cache,
                             output_attentions, output_hidden_states, use_return_dict)

        if model.zo_eval_loss_fn_post_hooks:
            if not model.zo_custom_eval_loss_fn:
                # Recover the logits for the hooks; in tuple form the loss (present
                # only when both positions were given) precedes them.
                if use_return_dict:
                    start_logits, end_logits = output["start_logits"], output["end_logits"]
                else:
                    offset = 1 if (start_positions is not None and end_positions is not None) else 0
                    start_logits, end_logits = output[offset], output[offset + 1]
            for post_hook_fn in model.zo_eval_loss_fn_post_hooks:
                output, input_ids, start_logits, start_positions, end_logits, end_positions = \
                    post_hook_fn(model, output, input_ids, start_logits, start_positions, end_logits, end_positions)

        return output
BaseZOModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] Args: config: OPTConfig """ def __init__(self, config: OPTConfig): """ !!! Module register must follow the execution order. """ OPTPreTrainedModel.__init__(self, config) self.dropout = config.dropout self.layerdrop = config.layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) if config.word_embed_proj_dim != config.hidden_size: self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) else: self.project_in = None self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility # with checkpoints that have been fine-tuned before transformers v4.20.1 # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm( config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine ) else: self.final_layer_norm = None if config.word_embed_proj_dim != config.hidden_size: self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) else: self.project_out = None self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def zo_init(self, zo_config): # Initialize ZO2 self.opt = OptimizerOPTDecoder(model=self, config=zo_config) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: 
class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel, BaseZOModel):
    """OPT backbone whose forward is routed through its ZO2 optimizer wrapper."""

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        BaseZOModel.__init__(self)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def zo_init(self, zo_config):
        # The decoder gets its optimizer wrapper first, then this module gets its own.
        self.decoder.zo_init(zo_config)
        # Initialize ZO2
        self.opt = OptimizerOPTModel(model=self, config=zo_config)

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Both entry points take the same positional arguments; only the target differs.
        forward_args = (
            input_ids,
            attention_mask,
            head_mask,
            past_key_values,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        if not self.zo_training:
            return self.opt.zo_eval_forward(*forward_args)
        return self.opt.inner_zo_forward(*forward_args)
[What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Returns: Example: ```python >>> from transformers import AutoTokenizer, OPTForCausalLM >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") >>> prompt = "Hey, are you consciours? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
```""" if self.zo_training: return self.opt.zo_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) else: return self.opt.zo_eval_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassification, OPTPreTrainedModel, BaseZOModel): def __init__(self, config: OPTConfig): OPTPreTrainedModel.__init__(self, config) BaseZOModel.__init__(self) self.num_labels = config.num_labels self.model = OPTModel(config) self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def zo_init(self, zo_config): self.model.zo_init(zo_config) self.opt = OptimizerOPTForSequenceClassification(model=self, config=zo_config) @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC, expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, expected_loss=_SEQ_CLASS_EXPECTED_LOSS, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ if self.zo_training: return self.opt.zo_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) else: return self.opt.zo_eval_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTPreTrainedModel, BaseZOModel): def __init__(self, config: OPTConfig): OPTPreTrainedModel.__init__(self, config) BaseZOModel.__init__(self) self.model = OPTModel(config) self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) # Initialize weights and apply final processing self.post_init() def zo_init(self, zo_config): self.model.zo_init(zo_config) self.opt = OptimizerOPTForQuestionAnswering(model=self, config=zo_config) @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled 
span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. Returns: Example: ```python >>> from transformers import AutoTokenizer, OPTForQuestionAnswering >>> import torch >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") >>> # note: we are loading a OPTForQuestionAnswering from the hub here, >>> # so the head will be randomly initialized, hence the predictions will be random >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m") >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> inputs = tokenizer(question, text, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> answer_start_index = outputs.start_logits.argmax() >>> answer_end_index = outputs.end_logits.argmax() >>> answer_offset = len(tokenizer(question)[0]) >>> predict_answer_tokens = inputs.input_ids[ ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1 ... 
] >>> predicted = tokenizer.decode(predict_answer_tokens) >>> predicted ' a nice puppet' ```""" if self.zo_training: return self.opt.zo_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, start_positions, end_positions, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) else: return self.opt.zo_eval_forward( input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, start_positions, end_positions, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs) class OptimizerOPTDecoder(MeZO2SGD): def init_zo2(self): self.upload_stream = None self.offload_stream = None self.compute_stream = None self.zo_random_seed = None self.rstate = None self.rstate_queue = None self.last_rstate = None self.projected_grad = None self.init_zo2_upload() def init_zo2_upload(self): self.model.embed_tokens = self.model.embed_tokens.to(self.device) self.model.embed_positions = self.model.embed_positions.to(self.device) if self.model.project_out: self.model.project_out = self.model.project_out.to(self.device) if self.model.project_in: self.model.project_in = self.model.project_in.to(self.device) if self.model.final_layer_norm: self.model.final_layer_norm = self.model.final_layer_norm.to(self.device) self.num_blocks = len(self.model.layers) if self.offloading_blocks is not None: self.offloading_blocks = self.offloading_blocks else: self.offloading_blocks = list(range(self.num_blocks)) print(f"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}") for i in range(self.num_blocks): if i in self.offloading_blocks: continue else: self.model.layers[i] = self.model.layers[i].to(self.device) print(f"Upload block {i} to {self.device}.") @torch.inference_mode def inner_zo_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: 
Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.model.config.use_cache return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: # inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds1, inputs_embeds2 = self.task_compute_module(self.model.embed_tokens, inputs1={"input": input_ids}, inputs2={"input": input_ids}, grad=self.projected_grad) else: inputs_embeds1 = inputs_embeds2 = inputs_embeds batch_size, seq_length = input_shape # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 past_key_values_length = 0 # required mask seq length can be calculated via length of past # mask_seq_length = past_key_values_length + seq_length mask_seq_length = seq_length past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_seen_tokens, 
past_seen_tokens + inputs_embeds1.shape[1], device=inputs_embeds1.device ) # embed positions if attention_mask is None: seq_length = past_seen_tokens + inputs_embeds1.shape[1] attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds1.device) # causal_mask = self.model._update_causal_mask( # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions # ) causal_attention_mask1, causal_attention_mask2 = self.task_compute_function( self.model._update_causal_mask, inputs1={"attention_mask": attention_mask, "input_tensor": inputs_embeds1, "cache_position": cache_position, "past_key_values": past_key_values, "output_attentions": output_attentions}, inputs2={"attention_mask": attention_mask, "input_tensor": inputs_embeds2, "cache_position": cache_position, "past_key_values": past_key_values, "output_attentions": output_attentions}, compute_sync=False ) # pos_embeds = self.model.embed_positions(attention_mask, past_key_values_length) pos_embeds1, pos_embeds2 = self.task_compute_module(self.model.embed_positions, inputs1={"attention_mask": attention_mask, "past_key_values_length": past_key_values_length}, inputs2={"attention_mask": attention_mask, "past_key_values_length": past_key_values_length}, grad=self.projected_grad, compute_sync=False) if self.model.project_in is not None: # inputs_embeds = self.model.project_in(inputs_embeds) inputs_embeds1, inputs_embeds2 = self.task_compute_module(self.model.project_in, inputs1={"input": inputs_embeds1}, inputs2={"input": inputs_embeds2}, grad=self.projected_grad, compute_sync=False) # hidden_states = inputs_embeds + pos_embeds hidden_states1, hidden_states2 = self.task_compute_function(torch.add, inputs1={"input": inputs_embeds1, "other": pos_embeds1}, inputs2={"input": inputs_embeds2, "other": pos_embeds2}, compute_sync=False) if 0 in self.offloading_blocks: self.model.layers[0] = self.task_upload( module=self.model.layers[0], device=self.device ) # if self.model.gradient_checkpointing 
and self.model.training: # if use_cache: # logger.warning_once( # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." # ) # use_cache = False # # decoder layers # all_hidden_states = () if output_hidden_states else None # all_self_attns = () if output_attentions else None # next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.model.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(self.model.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) N = len(self.model.layers) for i in range(1, N): if i != 1: if i-2 in self.offloading_blocks: self.model.layers[i-2] = self.task_offload( module=self.model.layers[i-2], device=self.offloading_device) layer_outputs1, layer_outputs2 = self.task_compute_module( self.model.layers[i-1], inputs1={"hidden_states": hidden_states1, "attention_mask": causal_attention_mask1, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "output_attentions": output_attentions}, inputs2={"hidden_states": hidden_states2, "attention_mask": causal_attention_mask2, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "output_attentions": output_attentions}, grad=self.projected_grad) # hidden_states = layer_outputs[0] hidden_states1, hidden_states2 = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs1}, inputs2={"input": layer_outputs2}, compute_sync=False ) if i in self.offloading_blocks: self.model.layers[i] = self.task_upload( module=self.model.layers[i], device=self.device) if N-2 in self.offloading_blocks: self.model.layers[N-2] = self.task_offload( module=self.model.layers[N-2], device=self.offloading_device) layer_outputs1, layer_outputs2 = self.task_compute_module( 
self.model.layers[N-1], inputs1={"hidden_states": hidden_states1, "attention_mask": causal_attention_mask1, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "output_attentions": output_attentions}, inputs2={"hidden_states": hidden_states2, "attention_mask": causal_attention_mask2, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "output_attentions": output_attentions}, grad=self.projected_grad) hidden_states1, hidden_states2 = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs1}, inputs2={"input": layer_outputs2}, compute_sync=False ) if N-1 in self.offloading_blocks: self.model.layers[N-1] = self.task_offload( module=self.model.layers[N-1], device=self.offloading_device) if self.model.final_layer_norm is not None: # hidden_states = self.model.final_layer_norm(hidden_states) hidden_states1, hidden_states2 = self.task_compute_module( module=self.model.final_layer_norm, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad, weight_decay=0.) 
if self.model.project_out is not None: # hidden_states = self.model.project_out(hidden_states) hidden_states1, hidden_states2 = self.task_compute_module( module=self.model.project_out, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad, compute_sync=False) return hidden_states1, hidden_states2 @torch.inference_mode def inner_zo_eval_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.model.config.use_cache return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: # inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.task_compute_module(self.model.embed_tokens, inputs1={"input": input_ids}, 
inputs2=None, grad=None) batch_size, seq_length = input_shape past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 # required mask seq length can be calculated via length of past mask_seq_length = past_key_values_length + seq_length past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) # embed positions if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) # causal_mask = self.model._update_causal_mask( # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions # ) causal_attention_mask = self.task_compute_function( self.model._update_causal_mask, inputs1={"attention_mask": attention_mask, "input_tensor": inputs_embeds, "cache_position": cache_position, "past_key_values": past_key_values, "output_attentions": output_attentions}, inputs2=None ) # pos_embeds = self.model.embed_positions(attention_mask, past_key_values_length) pos_embeds = self.task_compute_module(self.model.embed_positions, inputs1={"attention_mask": attention_mask, "past_key_values_length": past_key_values_length}, inputs2=None, grad=None, compute_sync=False) if self.model.project_in is not None: # inputs_embeds = self.model.project_in(inputs_embeds) inputs_embeds = self.task_compute_module(self.model.project_in, inputs1={"input": inputs_embeds}, inputs2=None, grad=None, compute_sync=False) # hidden_states = inputs_embeds + pos_embeds hidden_states = self.task_compute_function(torch.add, inputs1={"input": inputs_embeds, "other": pos_embeds}, inputs2=None, compute_sync=False) if self.model.gradient_checkpointing and self.model.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False # decoder layers # all_hidden_states = () if output_hidden_states else None # all_self_attns = () if output_attentions else None # next_decoder_cache = () if use_cache else None all_hidden_states = self.task_compute_function(init_all_hidden_states, inputs1={"output_hidden_states": output_hidden_states}, inputs2=None, compute_sync=False) all_self_attns = self.task_compute_function(init_all_self_attns, inputs1={"output_attentions": output_attentions}, inputs2=None, compute_sync=False) next_decoder_cache = self.task_compute_function(init_next_decoder_cache, inputs1={"use_cache": use_cache}, inputs2=None, compute_sync=False) if 0 in self.offloading_blocks: self.model.layers[0] = self.task_upload( module=self.model.layers[0], device=self.device ) # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.model.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(self.model.layers)} layers, but it is for" f" {head_mask.size()[0]}." 
) N = len(self.model.layers) for i in range(1, N): if i != 1: if i-2 in self.offloading_blocks: self.model.layers[i-2] = self.task_offload( module=self.model.layers[i-2], device=self.offloading_device) all_hidden_states = self.task_compute_function( fn=update_all_hidden_states, inputs1={"output_hidden_states": output_hidden_states, "all_hidden_states": all_hidden_states, "hidden_states": hidden_states}, inputs2=None, compute_sync=False) past_key_value = self.task_compute_function( fn=get_past_key_value, inputs1={"past_key_values": past_key_values, "idx": i}, inputs2=None, compute_sync=False) layer_outputs = self.task_compute_module( self.model.layers[i-1], inputs1={"hidden_states": hidden_states, "attention_mask": causal_attention_mask, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "past_key_value": past_key_value, "output_attentions": output_attentions, "use_cache": use_cache}, inputs2=None, grad=None) # hidden_states = layer_outputs[0] hidden_states = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs}, inputs2=None, compute_sync=False) next_decoder_cache = self.task_compute_function( fn=update_next_decoder_cache, inputs1={"use_cache": use_cache, "next_decoder_cache": next_decoder_cache, "layer_outputs": layer_outputs, "output_attentions": output_attentions}, inputs2=None, compute_sync=False) all_self_attns = self.task_compute_function( fn=update_all_self_attns, inputs1={"output_attentions": output_attentions, "all_self_attns": all_self_attns, "layer_outputs": layer_outputs}, inputs2=None, compute_sync=False) # an unknown bug here, need to synchronize the stream to avoid memory leak (only apears in opt-350m) if i in range(1, N-1, 2) and i in self.offloading_blocks: self.compute_stream.synchronize() # a weird but useful trick to avoid memory leak if i in self.offloading_blocks: self.model.layers[i] = self.task_upload( module=self.model.layers[i], device=self.device) if N-2 
in self.offloading_blocks: self.model.layers[N-2] = self.task_offload( module=self.model.layers[N-2], device=self.offloading_device) all_hidden_states = self.task_compute_function( fn=update_all_hidden_states, inputs1={"output_hidden_states": output_hidden_states, "all_hidden_states": all_hidden_states, "hidden_states": hidden_states}, inputs2=None) layer_outputs = self.task_compute_module( self.model.layers[N-1], inputs1={"hidden_states": hidden_states, "attention_mask": causal_attention_mask, "layer_head_mask": (head_mask[i-1] if head_mask is not None else None), "past_key_value": past_key_value, "output_attentions": output_attentions, "use_cache": use_cache}, inputs2=None, grad=None) hidden_states = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs}, inputs2=None, compute_sync=False) next_decoder_cache = self.task_compute_function( fn=update_next_decoder_cache, inputs1={"use_cache": use_cache, "next_decoder_cache": next_decoder_cache, "layer_outputs": layer_outputs, "output_attentions": output_attentions}, inputs2=None, compute_sync=False) all_self_attns = self.task_compute_function( fn=update_all_self_attns, inputs1={"output_attentions": output_attentions, "all_self_attns": all_self_attns, "layer_outputs": layer_outputs}, inputs2=None, compute_sync=False ) if N-1 in self.offloading_blocks: self.model.layers[N-1] = self.task_offload( module=self.model.layers[N-1], device=self.offloading_device) if self.model.final_layer_norm is not None: # hidden_states = self.model.final_layer_norm(hidden_states) hidden_states = self.task_compute_module( module=self.model.final_layer_norm, inputs1={"input": hidden_states}, inputs2=None, grad=None) if self.model.project_out is not None: # hidden_states = self.model.project_out(hidden_states) hidden_states = self.task_compute_module( module=self.model.project_out, inputs1={"input": hidden_states}, inputs2=None, grad=None, compute_sync=False) # add hidden states 
from the last decoder layer # if output_hidden_states: # all_hidden_states += (hidden_states,) all_hidden_states = self.task_compute_function( fn=update_all_hidden_states, inputs1={"output_hidden_states": output_hidden_states, "all_hidden_states": all_hidden_states, "hidden_states": hidden_states}, inputs2=None, compute_sync=False ) next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class OptimizerOPTModel(MeZO2SGD): def init_zo2(self): self.upload_stream = None self.offload_stream = None self.compute_stream = None self.zo_random_seed = None self.rstate = None self.rstate_queue = None self.last_rstate = None self.projected_grad = None self.init_zo2_upload() def init_zo2_upload(self): ... @torch.inference_mode def inner_zo_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.model.config.use_cache return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) self.model.decoder.zo_training = True 
self.assign_zo2_attributes(self, self.model.decoder.opt) output = self.model.decoder( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.decoder.opt, self) return output @torch.inference_mode def inner_zo_eval_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.model.config.use_cache return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict self.model.decoder.zo_training = False self.assign_zo2_attributes(self, self.model.decoder.opt) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.model.decoder( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.decoder.opt, self) if not return_dict: return decoder_outputs return BaseModelOutputWithPast( last_hidden_state=decoder_outputs.last_hidden_state, 
past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, ) class OptimizerOPTForCausalLM(MeZO2SGD): def init_zo2_upload(self): self.model.lm_head = self.model.lm_head.to(self.device) @torch.inference_mode def inner_zo_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, CausalLMOutputWithPast]: """ copy the original forward code and replace all 'self' to 'self.model'. """ output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model.model.decoder.zo_training = True self.assign_zo2_attributes(self, self.model.model.decoder.opt) hidden_states1, hidden_states2 = self.model.model.decoder( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.decoder.opt, self) # logits = self.model.lm_head(outputs[0]).contiguous() logits1, logits2 = self.task_compute_module(self.model.lm_head, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad) if 
self.model.zo_train_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks: (input_ids, logits1, labels), (input_ids, logits2, labels) = \ self.task_compute_function(pre_hook_fn, inputs1={"self": self.model, "input_ids": input_ids, "logits": logits1, "labels": labels}, inputs2={"self": self.model, "input_ids": input_ids, "logits": logits2, "labels": labels}) # loss = None if self.model.zo_custom_train_loss_fn: loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn, inputs1={"self": self.model, "input_ids": input_ids, "logits": logits1, "labels": labels, **kwargs}, inputs2={"self": self.model, "input_ids": input_ids, "logits": logits2, "labels": labels, **kwargs}) elif labels is not None: # Shift so that tokens < n predict n # shift_logits = logits[..., :-1, :].contiguous() shift_logits1, shift_logits2 = self.task_compute_function( fn=get_shift_logits, inputs1={"logits": logits1}, inputs2={"logits": logits2}) # shift_labels = labels[..., 1:].contiguous() shift_labels1, shift_labels2 = self.task_compute_function( fn=get_shift_labels, inputs1={"labels": labels}, inputs2={"labels": labels}) # Flatten the tokens loss_fct = CrossEntropyLoss() # loss = loss_fct(shift_logits.view(-1, self.model.config.vocab_size), shift_labels.view(-1)) loss1, loss2 = self.task_compute_function( fn=loss_fct, inputs1={"input": shift_logits1.view(-1, self.model.config.vocab_size), "target": shift_labels1.view(-1)}, inputs2={"input": shift_logits2.view(-1, self.model.config.vocab_size), "target": shift_labels2.view(-1)}) if self.model.zo_train_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_train_loss_fn_post_hooks: (loss1, input_ids, logits1, labels), (loss2, input_ids, logits2, labels) = \ self.task_compute_function(post_hook_fn, inputs1={"self": self.model, "loss": loss1, "input_ids": input_ids, "logits": logits1, "labels": labels}, inputs2={"self": self.model, "loss": loss2, "input_ids": input_ids, "logits": logits2, "labels": 
labels}) return loss1, loss2 @torch.inference_mode def inner_zo_eval_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict self.model.model.decoder.zo_training = False self.assign_zo2_attributes(self, self.model.model.decoder.opt) outputs = self.model.model.decoder( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.decoder.opt, self) hidden_states = self.task_compute_function( fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": outputs}, inputs2=None, compute_sync=False ) logits = self.task_compute_module(self.model.lm_head, inputs1={"input": hidden_states}, inputs2=None, grad=self.projected_grad) if self.model.zo_eval_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks: input_ids, logits, labels = \ self.task_compute_function(pre_hook_fn, inputs1=([self.model], {"input_ids": input_ids, "logits": logits, "labels": labels}), inputs2=None) loss = None if self.model.zo_custom_eval_loss_fn: loss = 
self.task_compute_function( fn=self.model.zo_custom_eval_loss_fn, inputs1=([self.model], {"input_ids": input_ids, "logits": logits, "labels": labels, **kwargs}), inputs2=None, compute_sync=False ) elif labels is not None: # Shift so that tokens < n predict n # shift_logits = logits[..., :-1, :].contiguous() shift_logits = self.task_compute_function( fn=get_shift_logits, inputs1={"logits": logits}, inputs2=None) # shift_labels = labels[..., 1:].contiguous() shift_labels = self.task_compute_function( fn=get_shift_labels, inputs1={"labels": labels}, inputs2=None) # Flatten the tokens loss_fct = CrossEntropyLoss() # loss = loss_fct(shift_logits.view(-1, self.model.config.vocab_size), shift_labels.view(-1)) loss = self.task_compute_function( fn=loss_fct, inputs1={"input": shift_logits.view(-1, self.model.config.vocab_size), "target": shift_labels.view(-1)}, inputs2=None) if self.model.zo_eval_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks: output, input_ids, logits, labels = \ self.task_compute_function(post_hook_fn, inputs1=([self.model], {"loss": loss, "input_ids": input_ids, "logits": logits, "labels": labels}), inputs2=None) if not return_dict: output = (logits,) + outputs[1] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class OptimizerOPTForSequenceClassification(MeZO2SGD): def init_zo2_upload(self): self.model.score = self.model.score.to(self.device) @torch.inference_mode def inner_zo_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: 
Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, SequenceClassifierOutputWithPast]: """ copy the original forward code and replace all 'self' to 'self.model'. """ return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict self.model.model.decoder.zo_training = True self.assign_zo2_attributes(self, self.model.model.opt) hidden_states1, hidden_states2 = self.model.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.opt, self) # hidden_states = transformer_outputs[0] # logits = self.model.score(hidden_states) logits1, logits2 = self.task_compute_module(self.model.score, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] else: batch_size, sequence_length = inputs_embeds.shape[:2] if self.model.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = (torch.ne(input_ids, self.model.config.pad_token_id).sum(-1) - 1).to(logits1.device) else: sequence_lengths = -1 logger.warning( f"{self.model.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) # pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] pooled_logits1, pooled_logits2 = self.task_compute_function( fn=get_pooled_logits, inputs1={"logits": logits1, "batch_size": batch_size, "sequence_lengths": sequence_lengths}, inputs2={"logits": logits2, "batch_size": batch_size, "sequence_lengths": sequence_lengths},) if self.model.zo_train_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks: (input_ids, pooled_logits1, labels), (input_ids, pooled_logits2, labels) = \ self.task_compute_function(pre_hook_fn, inputs1={"self": self, "input_ids": input_ids, "logits": pooled_logits1, "labels": labels}, inputs2={"self": self, "input_ids": input_ids, "logits": pooled_logits2, "labels": labels}) # loss = None if self.model.zo_custom_train_loss_fn: loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn, inputs1={"self": self.model, "input_ids": input_ids, "logits": pooled_logits1, "labels": labels, **kwargs}, inputs2={"self": self.model, "input_ids": input_ids, "logits": pooled_logits2, "labels": labels, **kwargs}) elif labels is not None: if self.model.config.problem_type is None: if self.model.num_labels == 1: self.model.config.problem_type = "regression" elif self.model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.model.config.problem_type = "single_label_classification" else: self.model.config.problem_type = "multi_label_classification" if self.model.config.problem_type == "regression": loss_fct = MSELoss() if self.model.num_labels == 1: # loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) loss1, loss2 = self.task_compute_function( fn=loss_fct, inputs1={"input": pooled_logits1.squeeze(), "target": labels.squeeze()}, inputs2={"input": pooled_logits2.squeeze(), "target": labels.squeeze()},) else: # loss = loss_fct(pooled_logits, labels) loss1, 
loss2 = self.task_compute_function( fn=loss_fct, inputs1={"input": pooled_logits1, "target": labels}, inputs2={"input": pooled_logits2, "target": labels},) elif self.model.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() # loss = loss_fct(pooled_logits.view(-1, self.model.num_labels), labels.view(-1)) loss1, loss2 = self.task_compute_function( fn=loss_fct, inputs1={"input": pooled_logits1.view(-1, self.model.num_labels), "target": labels.view(-1)}, inputs2={"input": pooled_logits2.view(-1, self.model.num_labels), "target": labels.view(-1)},) elif self.model.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() # loss = loss_fct(pooled_logits, labels) loss1, loss2 = self.task_compute_function( fn=loss_fct, inputs1={"input": pooled_logits1, "target": labels}, inputs2={"input": pooled_logits2, "target": labels},) if self.model.zo_train_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_train_loss_fn_post_hooks: (loss1, input_ids, pooled_logits1, labels), (loss2, input_ids, pooled_logits2, labels) = \ self.task_compute_function(post_hook_fn, inputs1={"self": self.model, "loss": loss1, "input_ids": input_ids, "logits": pooled_logits1, "labels": labels}, inputs2={"self": self.model, "loss": loss2, "input_ids": input_ids, "logits": pooled_logits2, "labels": labels}) return loss1, loss2 @torch.inference_mode def inner_zo_eval_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, SequenceClassifierOutputWithPast]: output_attentions = output_attentions if 
output_attentions is not None else self.model.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.model.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.model.model.config.use_return_dict self.model.model.zo_training = False self.assign_zo2_attributes(self, self.model.model.opt) transformer_outputs = self.model.model( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.opt, self) hidden_states = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": transformer_outputs}, inputs2=None) logits = self.task_compute_module(self.model.score, inputs1={"input": hidden_states}, inputs2=None, grad=self.projected_grad) pooled_logits = self.task_compute_function( fn=get_opt_sequence_classification_pooled_logits, inputs1=([self.model], {"logits": logits, "input_ids": input_ids, "inputs_embeds": inputs_embeds}), inputs2=None, compute_sync=False) if self.model.zo_eval_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks: input_ids, logits, labels = self.task_compute_function(pre_hook_fn, inputs1=([self.model], {"input_ids": input_ids, "logits": logits, "labels": labels}), inputs2=None, compute_sync=False) loss = None if self.model.zo_custom_eval_loss_fn: loss = self.task_compute_function( fn=self.model.zo_custom_eval_loss_fn, inputs1=([self.model], {"input_ids": input_ids, "pooled_logits": pooled_logits, "labels": labels, **kwargs}), inputs2=None, compute_sync=False ) elif labels is not None: loss = self.task_compute_function( fn=get_opt_sequence_classification_loss, inputs1=([self.model], {"loss": loss, "pooled_logits": pooled_logits, 
"labels": labels}), inputs2=None, compute_sync=False ) if self.model.zo_eval_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks: transformer_outputs, input_ids, logits, labels = self.task_compute_function(post_hook_fn, inputs1=([self.model], {"transformer_outputs": transformer_outputs, "input_ids": input_ids, "pooled_logits": pooled_logits, "labels": labels}), inputs2=None, compute_sync=False) if not return_dict: transformer_outputs = (logits,) + transformer_outputs[1:] return ((loss,) + transformer_outputs) if loss is not None else transformer_outputs return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) class OptimizerOPTForQuestionAnswering(MeZO2SGD): def init_zo2_upload(self): self.model.qa_outputs = self.model.qa_outputs.to(self.device) @torch.inference_mode def inner_zo_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, QuestionAnsweringModelOutput]: """ copy the original forward code and replace all 'self' to 'self.model'. 
""" return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict self.model.model.decoder.zo_training = True self.assign_zo2_attributes(self, self.model.model.opt) hidden_states1, hidden_states2 = self.model.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.opt, self) # hidden_states = transformer_outputs[0] # logits = self.model.qa_outputs(hidden_states) logits1, logits2 = self.task_compute_module(self.model.qa_outputs, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad) # start_logits, end_logits = logits.split(1, dim=-1) # start_logits = start_logits.squeeze(-1).contiguous() # end_logits = end_logits.squeeze(-1).contiguous() (start_logits1, end_logits1), (start_logits2, end_logits2) = self.task_compute_function( fn=get_start_logits_and_end_logits, inputs1={"logits": logits1}, inputs2={"logits": logits2},) if self.model.zo_train_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks: (input_ids, start_logits1, start_positions, end_logits1, end_positions), (input_ids, start_logits2, start_positions, end_logits2, end_positions) = \ self.task_compute_function(pre_hook_fn, inputs1={"self": self, "input_ids": input_ids, "start_logits": start_logits1, "start_positions": start_positions, "end_logits": end_logits1, "end_positions": end_positions}, inputs2={"self": self, "input_ids": input_ids, "start_logits": start_logits2, "start_positions": start_positions, "end_logits": end_logits2, "end_positions": end_positions}) # total_loss = None if self.model.zo_custom_train_loss_fn: loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn, inputs1={"self": self.model, "input_ids": input_ids, "start_logits": 
start_logits1, "start_positions": start_positions, "end_logits": end_logits1, "end_positions": end_positions, **kwargs}, inputs2={"self": self.model, "input_ids": input_ids, "start_logits": start_logits2, "start_positions": start_positions, "end_logits": end_logits2, "end_positions": end_positions, **kwargs}) elif start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits1.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) # start_loss = loss_fct(start_logits, start_positions) # end_loss = loss_fct(end_logits, end_positions) # total_loss = (start_loss + end_loss) / 2 loss1, loss2 = self.task_compute_function( fn=get_qa_loss, inputs1={"loss_fct": loss_fct, "start_logits": start_logits1, "start_positions": start_positions, "end_logits": end_logits1, "end_positions": end_positions}, inputs2={"loss_fct": loss_fct, "start_logits": start_logits2, "start_positions": start_positions, "end_logits": end_logits2, "end_positions": end_positions}) if self.model.zo_train_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_train_loss_fn_post_hooks: (loss1, input_ids, start_logits1, start_positions, end_logits1, end_positions), (loss2, input_ids, start_logits2, start_positions, end_logits2, end_positions) = \ self.task_compute_function(post_hook_fn, inputs1={"self": self.model, "loss": loss1, "input_ids": input_ids, "start_logits": start_logits1, "start_positions": start_positions, "end_logits": end_logits1, "end_positions": end_positions}, inputs2={"self": self.model, "loss": loss2, "input_ids": input_ids, "start_logits": 
start_logits2, "start_positions": start_positions, "end_logits": end_logits2, "end_positions": end_positions}) return loss1, loss2 @torch.inference_mode def inner_zo_eval_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, QuestionAnsweringModelOutput]: return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict self.model.model.zo_training = False self.assign_zo2_attributes(self, self.model.model.opt) transformer_outputs = self.model.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) self.assign_zo2_attributes(self.model.model.opt, self) hidden_states = self.task_compute_function( fn=fn_get_opt_decoder_hidden_states_from_layer_outputs, inputs1={"input": transformer_outputs}, inputs2=None) logits = self.task_compute_module(self.model.qa_outputs, inputs1={"input": hidden_states}, inputs2=None, grad=self.projected_grad) start_logits, end_logits = self.task_compute_function( fn=get_start_logits_and_end_logits, inputs1={"logits": logits}, inputs2=None, compute_sync=False) if self.model.zo_eval_loss_fn_pre_hooks != []: for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks: input_ids, start_logits, start_positions, end_logits, end_positions = self.task_compute_function(pre_hook_fn, inputs1=([self.model], {"input_ids": input_ids, 
"start_logits": start_logits, "start_positions": start_positions, "end_logits": end_logits, "end_positions": end_positions}), inputs2=None, compute_sync=False) total_loss = None if self.model.zo_custom_eval_loss_fn: total_loss = self.task_compute_function(self.model.zo_custom_eval_loss_fn, inputs1=([self.model], {"input_ids": input_ids, "start_logits": start_logits, "start_positions": start_positions, "end_logits": end_logits, "end_positions": end_positions, **kwargs}), inputs2=None, compute_sync=False) elif start_positions is not None and end_positions is not None: total_loss = self.task_compute_function( fn=get_opt_question_answering_loss, inputs1={"total_loss": total_loss, "start_logits": start_logits, "start_positions": start_positions, "end_logits": end_logits, "end_positions": end_positions}, inputs2=None, compute_sync=False) if self.model.zo_eval_loss_fn_post_hooks != []: for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks: transformer_outputs, input_ids, start_logits, start_positions, end_logits, end_positions = self.task_compute_function(post_hook_fn, inputs1=([self.model], {"transformer_outputs": transformer_outputs, "input_ids": input_ids, "start_logits": start_logits, "start_positions": start_positions, "end_logits": end_logits, "end_positions": end_positions}), inputs2=None, compute_sync=False) if not return_dict: output = (start_logits, end_logits) + transformer_outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) ================================================ FILE: zo2/model/huggingface/qwen3/__init__.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 from . 
def fn_get_qwen3_sliced_logits_from_hidden_states(hidden_states, slice_indices):
    """Select positions along dimension 1 of ``hidden_states``.

    Despite the name, this slices the hidden states that are later fed to the
    LM head; it does not touch logits itself.

    Args:
        hidden_states: tensor indexed as ``[:, slice_indices, :]``.
        slice_indices: a ``slice`` or an index tensor for dimension 1.

    Returns:
        The indexed tensor, with dimensions 0 and 2 kept whole.
    """
    keep_all = slice(None)
    selector = (keep_all, slice_indices, keep_all)
    return hidden_states[selector]
class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel):
    """
    Qwen3 decoder stack made of ``config.num_hidden_layers`` [`Qwen3DecoderLayer`]
    blocks, rebuilt here with the KV cache disabled on the config.

    Args:
        config: Qwen3Config
    """

    def __init__(self, config: Qwen3Config):
        # The cache is switched off on the config before the base initialiser sees it.
        config.use_cache = False
        Qwen3PreTrainedModel.__init__(self, config)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Sub-modules are created in the same order as the original definition.
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
        )
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)

        decoder_blocks = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_blocks.append(Qwen3DecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_blocks)

        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, Qwen3ForCausalLM >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") >>> prompt = "Hey, are you conscious? Can you talk to me?" 
class OptimizerQwen3ForCausalLM(MeZOSGD):
    """MeZO-SGD optimizer bound to a ``Qwen3ForCausalLM`` (``self.model``).

    Provides the loss-only training forward and the evaluation forward that the
    model's ``forward`` dispatches to.
    """

    @torch.inference_mode
    def inner_zo_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Run one forward pass and return only the detached loss.

        This mirrors the upstream ``Qwen3ForCausalLM.forward`` with ``self``
        replaced by ``self.model``, plus the ZO loss hooks.

        Args:
            labels: targets for the loss; required unless a post-hook supplies
                the loss itself.
            logits_to_keep: an ``int`` keeps the last ``logits_to_keep``
                positions (``0`` keeps all); a tensor is used as an index.
            **kwargs: forwarded to the decoder and to the loss function.

        Returns:
            The loss tensor, detached. (The annotation is kept from the
            upstream signature.)

        Raises:
            ValueError: if no loss is available after the post-hooks ran,
                e.g. ``labels`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.model.lm_head(hidden_states[:, slice_indices, :])

        for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:
            input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)

        loss = None
        if labels is not None:
            if self.model.zo_custom_train_loss_fn:
                loss = self.model.zo_custom_train_loss_fn(self.model, input_ids, logits, labels, **kwargs)
            else:
                loss = self.model.loss_function(
                    logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs)

        for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:
            loss, input_ids, logits, labels = post_hook_fn(self.model, loss, input_ids, logits, labels)

        # Previously `None.detach()` raised an opaque AttributeError here.
        if loss is None:
            raise ValueError(
                "ZO training forward produced no loss: pass `labels` (or provide the loss from a post-hook)."
            )
        return loss.detach()

    @torch.inference_mode
    def inner_zo_eval_forward(
        self,
        eval_fn,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Evaluation forward pass with optional custom loss and loss hooks.

        The order matches the training path: model forward -> pre-hooks ->
        loss -> post-hooks. When a custom eval loss or any pre-hook is
        registered, ``eval_fn`` is called without labels so the hooks can see
        (and replace) the logits before the loss is computed here. Otherwise
        ``eval_fn`` computes the loss itself.

        Args:
            eval_fn: callable with the positional layout of
                ``Qwen3ForCausalLM.forward`` (the model passes ``super().forward``).
            labels: optional targets; no loss is computed when ``None``.
            **kwargs: forwarded to ``eval_fn`` and to the loss function.

        Returns:
            The model output, after the post-hooks have had a chance to replace it.
        """
        pre_hooks = self.model.zo_eval_loss_fn_pre_hooks
        custom_loss_fn = self.model.zo_custom_eval_loss_fn

        if custom_loss_fn or pre_hooks:
            # Labels are withheld so the loss is computed below, after the pre-hooks.
            output = eval_fn(input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
                             None, use_cache, output_attentions, output_hidden_states, cache_position,
                             logits_to_keep, **kwargs)
            logits = output["logits"]

            # The original ran the pre-hooks before `logits` existed (UnboundLocalError).
            for pre_hook_fn in pre_hooks:
                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)

            loss = None
            if labels is not None:
                if custom_loss_fn:
                    loss = custom_loss_fn(self.model, input_ids, logits, labels, **kwargs)
                else:
                    loss = self.model.loss_function(
                        logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs)

            output = CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=output.past_key_values,
                hidden_states=output.hidden_states,
                attentions=output.attentions,
            )
        else:
            output = eval_fn(input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
                             labels, use_cache, output_attentions, output_hidden_states, cache_position,
                             logits_to_keep, **kwargs)
            # Bind logits for the post-hooks; tolerate a plain tuple output.
            logits = getattr(output, "logits", None)

        for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:
            output, input_ids, logits, labels = post_hook_fn(self.model, output, input_ids, logits, labels)

        return output
class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel, BaseZOModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`].

    ZO2 variant: `forward` does not run the layers itself but delegates to the
    optimizer created by `zo_init` (`self.opt`).

    Args:
        config: Qwen3Config
    """

    def __init__(self, config: Qwen3Config):
        """
        Build the decoder sub-modules.

        !!! Module register must follow the execution order.
        Do not reorder the assignments below.
        """
        # The KV cache is disabled on the config before the base initialiser runs.
        config.use_cache = False
        Qwen3PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def zo_init(self, zo_config):
        """Create the ZO2 optimizer for this decoder and store it as `self.opt`."""
        # Initialize ZO2
        self.opt = OptimizerQwen3Model(model=self, config=zo_config)

    @can_return_tuple
    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # `zo_training` is not set in this __init__; within this file it is
        # assigned by the causal-LM optimizer right before it calls this model.
        if self.zo_training:
            # Training: the dual-input forward is invoked directly on the optimizer.
            return self.opt.inner_zo_forward(input_ids, attention_mask, position_ids, past_key_values,
                inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **flash_attn_kwargs)
        else:
            # Evaluation: goes through the optimizer's `zo_eval_forward` entry point.
            return self.opt.zo_eval_forward(input_ids, attention_mask, position_ids, past_key_values,
                inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **flash_attn_kwargs)
cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, Qwen3ForCausalLM >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
def init_zo2_upload(self):
    """Place the always-resident modules and the non-offloaded blocks on ``self.device``.

    The embedding, rotary embedding and final norm are always moved. When
    ``self.offloading_blocks`` is ``None`` every decoder block is marked for
    offloading; blocks not listed there are moved to ``self.device`` here.
    """
    backbone = self.model
    target = self.device

    backbone.embed_tokens = backbone.embed_tokens.to(target)
    backbone.rotary_emb = backbone.rotary_emb.to(target)
    backbone.norm = backbone.norm.to(target)

    self.num_blocks = len(backbone.layers)
    if self.offloading_blocks is None:
        # Default: offload every transformer block.
        self.offloading_blocks = list(range(self.num_blocks))
    print(f"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}")

    for block_idx in range(self.num_blocks):
        if block_idx not in self.offloading_blocks:
            backbone.layers[block_idx] = backbone.layers[block_idx].to(self.device)
            print(f"Upload block {block_idx} to {self.device}.")
output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) # use_cache = use_cache if use_cache is not None else self.model.config.use_cache use_cache = False if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.model.gradient_checkpointing and self.model.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache if not isinstance(past_key_values, (type(None), Cache)): raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") if inputs_embeds is None: # inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds1, inputs_embeds2 = self.task_compute_module( self.model.embed_tokens, inputs1={"input": input_ids}, inputs2={"input": input_ids}, grad=self.projected_grad ) else: inputs_embeds1 = inputs_embeds2 = inputs_embeds if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds1.shape[1], device=inputs_embeds1.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask1, causal_mask2 = self.task_compute_function( self.model._update_causal_mask, inputs1={"attention_mask": attention_mask, "input_tensor": inputs_embeds1, "cache_position": cache_position, "past_key_values": past_key_values, "output_attentions": output_attentions}, inputs2={"attention_mask": attention_mask, "input_tensor": inputs_embeds2, "cache_position": cache_position, "past_key_values": past_key_values, 
"output_attentions": output_attentions}, compute_sync=False, ) hidden_states1, hidden_states2 = inputs_embeds1, inputs_embeds2 # create position embeddings to be shared across the decoder layers position_embeddings1, position_embeddings2 = self.task_compute_module( self.model.rotary_emb, inputs1={"x": hidden_states1, "position_ids": position_ids}, inputs2={"x": hidden_states2, "position_ids": position_ids}, grad=self.projected_grad, compute_sync=False ) if 0 in self.offloading_blocks: self.model.layers[0] = self.task_upload( module=self.model.layers[0], device=self.device ) N = self.model.config.num_hidden_layers for i in range(1, N): if i != 1: if i-2 in self.offloading_blocks: self.model.layers[i-2] = self.task_offload( module=self.model.layers[i-2], device=self.offloading_device) layer_outputs1, layer_outputs2 = self.task_compute_module( self.model.layers[i-1], inputs1={"hidden_states": hidden_states1, "attention_mask": causal_mask1, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings1, **flash_attn_kwargs}, inputs2={"hidden_states": hidden_states2, "attention_mask": causal_mask2, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings2, **flash_attn_kwargs}, grad=self.projected_grad ) hidden_states1, hidden_states2 = self.task_compute_function( fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs1}, inputs2={"input": layer_outputs2}, compute_sync=False ) if i in self.offloading_blocks: self.model.layers[i] = self.task_upload( module=self.model.layers[i], device=self.device) if N-2 in self.offloading_blocks: self.model.layers[N-2] = self.task_offload( module=self.model.layers[N-2], device=self.offloading_device) layer_outputs1, 
layer_outputs2 = self.task_compute_module( self.model.layers[N-1], inputs1={"hidden_states": hidden_states1, "attention_mask": causal_mask1, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings1, **flash_attn_kwargs}, inputs2={"hidden_states": hidden_states2, "attention_mask": causal_mask2, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings2, **flash_attn_kwargs}, grad=self.projected_grad ) hidden_states1, hidden_states2 = self.task_compute_function( fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs1}, inputs2={"input": layer_outputs2}, compute_sync=False ) if N-1 in self.offloading_blocks: self.model.layers[N-1] = self.task_offload( module=self.model.layers[N-1], device=self.offloading_device) hidden_states1, hidden_states2 = self.task_compute_module( module=self.model.norm, inputs1={"hidden_states": hidden_states1}, inputs2={"hidden_states": hidden_states2}, grad=self.projected_grad, # weight_decay=0. 
) return hidden_states1, hidden_states2 @torch.inference_mode def inner_zo_eval_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states ) # use_cache = use_cache if use_cache is not None else self.model.config.use_cache use_cache = False if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.model.gradient_checkpointing and self.model.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) use_cache = False # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache if not isinstance(past_key_values, (type(None), Cache)): raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") if inputs_embeds is None: # inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.task_compute_module( self.model.embed_tokens, inputs1={"input": input_ids}, inputs2=None, grad=None ) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self.task_compute_function( self.model._update_causal_mask, inputs1={"attention_mask": attention_mask, "input_tensor": inputs_embeds, "cache_position": cache_position, "past_key_values": past_key_values, "output_attentions": output_attentions}, inputs2=None, compute_sync=False, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.task_compute_module( self.model.rotary_emb, inputs1={"x": hidden_states, "position_ids": position_ids}, inputs2=None, grad=None, compute_sync=False ) if 0 in self.offloading_blocks: self.model.layers[0] = self.task_upload( module=self.model.layers[0], device=self.device ) N = self.model.config.num_hidden_layers for i in range(1, N): if i != 1: if i-2 in self.offloading_blocks: self.model.layers[i-2] = self.task_offload( module=self.model.layers[i-2], device=self.offloading_device) layer_outputs = self.task_compute_module( self.model.layers[i-1], inputs1={"hidden_states": hidden_states, "attention_mask": causal_mask, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, 
"use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings, **flash_attn_kwargs}, inputs2=None, grad=None ) hidden_states = self.task_compute_function( fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs}, inputs2=None, compute_sync=False ) if i in self.offloading_blocks: self.model.layers[i] = self.task_upload( module=self.model.layers[i], device=self.device) if N-2 in self.offloading_blocks: self.model.layers[N-2] = self.task_offload( module=self.model.layers[N-2], device=self.offloading_device) layer_outputs = self.task_compute_module( self.model.layers[N-1], inputs1={"hidden_states": hidden_states, "attention_mask": causal_mask, "position_ids": position_ids, "past_key_value": past_key_values, "output_attentions": output_attentions, "use_cache": use_cache, "cache_position": cache_position, "position_embeddings": position_embeddings, **flash_attn_kwargs}, inputs2=None, grad=None ) hidden_states = self.task_compute_function( fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs, inputs1={"input": layer_outputs}, inputs2=None, compute_sync=False ) if N-1 in self.offloading_blocks: self.model.layers[N-1] = self.task_offload( module=self.model.layers[N-1], device=self.offloading_device) hidden_states = self.task_compute_module( module=self.model.norm, inputs1={"hidden_states": hidden_states}, inputs2=None, grad=None, # weight_decay=0. 
class OptimizerQwen3ForCausalLM(MeZO2SGD):
    """ZO2 (offloading) MeZO-SGD optimizer for the ``Qwen3ForCausalLM`` head.

    The decoder stack is driven by the inner model's own optimizer
    (``self.model.model.opt``); this class adds the LM head, the loss and the
    loss hooks on top of it.
    """

    def init_zo2_upload(self):
        """Move the LM head to ``self.device``."""
        self.model.lm_head = self.model.lm_head.to(self.device)

    @torch.inference_mode
    def inner_zo_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Dual-stream training forward returning the two losses.

        The inner model is called in ``zo_training`` mode and yields two hidden
        state tensors; both are sliced, projected through the LM head and scored.

        Args:
            labels: targets for the loss; required unless a post-hook supplies
                the losses itself.
            logits_to_keep: an ``int`` keeps the last ``logits_to_keep``
                positions (``0`` keeps all); a tensor is used as an index.
            **kwargs: forwarded to the inner model and to the loss function.

        Returns:
            ``(loss1, loss2)``. (The annotation is kept from the upstream signature.)

        Raises:
            ValueError: if no losses are available after the post-hooks ran,
                e.g. ``labels`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )

        # Lend this optimizer's ZO2 state to the inner optimizer for the decoder
        # pass, then copy it back afterwards.
        self.model.model.zo_training = True
        self.assign_zo2_attributes(self, self.model.model.opt)
        hidden_states1, hidden_states2 = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
        self.assign_zo2_attributes(self.model.model.opt, self)

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden_states1, hidden_states2 = self.task_compute_function(
            fn_get_qwen3_sliced_logits_from_hidden_states,
            inputs1={"hidden_states": hidden_states1, "slice_indices": slice_indices},
            inputs2={"hidden_states": hidden_states2, "slice_indices": slice_indices},
        )
        logits1, logits2 = self.task_compute_module(
            self.model.lm_head,
            inputs1={"input": hidden_states1},
            inputs2={"input": hidden_states2},
            grad=self.projected_grad)

        for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:
            (input_ids, logits1, labels), (input_ids, logits2, labels) = \
                self.task_compute_function(
                    pre_hook_fn,
                    inputs1={"self": self.model, "input_ids": input_ids, "logits": logits1, "labels": labels},
                    inputs2={"self": self.model, "input_ids": input_ids, "logits": logits2, "labels": labels})

        # Bound up front so a label-free call fails with a clear error below
        # instead of an UnboundLocalError.
        loss1 = loss2 = None
        if labels is not None:
            if self.model.zo_custom_train_loss_fn:
                loss1, loss2 = self.task_compute_function(
                    self.model.zo_custom_train_loss_fn,
                    inputs1={"self": self.model, "input_ids": input_ids, "logits": logits1, "labels": labels, **kwargs},
                    inputs2={"self": self.model, "input_ids": input_ids, "logits": logits2, "labels": labels, **kwargs})
            else:
                loss1, loss2 = self.task_compute_function(
                    self.model.loss_function,
                    inputs1={"logits": logits1, "labels": labels, "vocab_size": self.model.config.vocab_size, **kwargs},
                    inputs2={"logits": logits2, "labels": labels, "vocab_size": self.model.config.vocab_size, **kwargs},
                )

        for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:
            (loss1, input_ids, logits1, labels), (loss2, input_ids, logits2, labels) = \
                self.task_compute_function(
                    post_hook_fn,
                    inputs1={"self": self.model, "loss": loss1, "input_ids": input_ids, "logits": logits1, "labels": labels},
                    inputs2={"self": self.model, "loss": loss2, "input_ids": input_ids, "logits": logits2, "labels": labels})

        if loss1 is None or loss2 is None:
            raise ValueError(
                "ZO training forward produced no loss: pass `labels` (or provide the loss from a post-hook)."
            )
        return loss1, loss2

    @torch.inference_mode
    def inner_zo_eval_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Single-stream evaluation forward.

        Args:
            labels: optional targets; when ``None`` the returned ``loss`` is ``None``.
            logits_to_keep: an ``int`` keeps the last ``logits_to_keep``
                positions (``0`` keeps all); a tensor is used as an index.
            **kwargs: forwarded to the inner model and to the loss function.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` and ``logits`` only;
            cache, hidden states and attentions are always ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )

        # Same state hand-off as in training, but with the inner model in eval mode.
        self.model.model.zo_training = False
        self.assign_zo2_attributes(self, self.model.model.opt)
        hidden_states = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )
        self.assign_zo2_attributes(self.model.model.opt, self)

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden_states = self.task_compute_function(
            fn_get_qwen3_sliced_logits_from_hidden_states,
            inputs1={"hidden_states": hidden_states, "slice_indices": slice_indices},
            inputs2=None,
        )
        logits = self.task_compute_module(
            self.model.lm_head,
            inputs1={"input": hidden_states},
            inputs2=None,
            grad=None)

        for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:
            input_ids, logits, labels = \
                self.task_compute_function(
                    pre_hook_fn,
                    inputs1=([self.model], {"input_ids": input_ids, "logits": logits, "labels": labels}),
                    inputs2=None)

        # `loss` must exist even without labels; the original left it unbound,
        # so label-free inference crashed when building the output.
        loss = None
        if labels is not None:
            if self.model.zo_custom_eval_loss_fn:
                loss = self.task_compute_function(
                    self.model.zo_custom_eval_loss_fn,
                    inputs1=([self.model], {"input_ids": input_ids, "logits": logits, "labels": labels, **kwargs}),
                    inputs2=None)
            else:
                loss = self.task_compute_function(
                    self.model.loss_function,
                    inputs1={"logits": logits, "labels": labels, "vocab_size": self.model.config.vocab_size, **kwargs},
                    inputs2=None
                )

        for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:
            loss, input_ids, logits, labels = \
                self.task_compute_function(
                    post_hook_fn,
                    inputs1=([self.model], {"loss": loss, "input_ids": input_ids, "logits": logits, "labels": labels}),
                    inputs2=None)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
# Stock transformers class -> factory returning its ZO2 replacement class for a given zo_config.
_zo2_supported_models = {
    transformers.OPTForCausalLM: opt.get_opt_for_causalLM,
    transformers.OPTForSequenceClassification: opt.get_opt_for_sequence_classification,
    transformers.OPTForQuestionAnswering: opt.get_opt_for_question_answering,
    # transformers.LlamaForCausalLM: llama.get_llama_for_causalLM,
    transformers.Qwen3ForCausalLM: qwen3.get_qwen3_for_causalLM,
}


@contextmanager
def zo_hf_init(zo_config):
    """
    Context manager that swaps the supported ``transformers`` model classes for their
    ZO2 counterparts, so that ``from transformers import <Model>`` inside the ``with``
    body yields the ZO2 class.

    The replacement is written onto the ``transformers`` module itself and is not
    reverted when the context exits.

    Args:
        zo_config: ZO configuration forwarded to every class factory.

    Raises:
        NotImplementedError: if ``transformers`` has no attribute named after a supported class.
    """
    for hf_class, build_zo2_class in _zo2_supported_models.items():
        class_name = hf_class.__name__
        if not hasattr(transformers, class_name):
            raise NotImplementedError(f"Model '{class_name}' is not supported in transformers.")
        setattr(transformers, class_name, build_zo2_class(zo_config))
    yield


def main():
    """Usage sketch only: ``zo_config`` and the ``from_pretrained`` arguments are placeholders."""
    # user api:
    with zo_hf_init(zo_config):
        from transformers import OPTForCausalLM
        model = OPTForCausalLM.from_pretrained(...)
        model.zo_init(zo_config)
    print(type(model)) # should be zo2.OPTForCausalLM
class GPT(model.GPT, BaseZOModel):
    """nanoGPT model whose forward pass is routed through the MeZO-SGD ``Optimizer``."""

    def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig):
        """Build the plain GPT from ``config`` and attach an ``Optimizer`` configured by ``zo_config``."""
        super().__init__(config)
        self.opt = Optimizer(model=self, config=zo_config)

    def forward(self, idx, pos, targets=None):
        """Dispatch to the ZO training step or to the plain forward, depending on ``self.zo_training``."""
        if not self.zo_training:
            # for evaluate and inference purpose
            plain_forward = super().forward
            return self.opt.zo_eval_forward(plain_forward, idx, pos, targets)
        return self.opt.zo_forward(idx, pos, targets)
class Optimizer(MeZO2SGD):
    """
    MeZO2SGD specialisation for the nanoGPT model: defines which sub-modules live on the
    working device and the order in which transformer blocks are uploaded, computed and
    offloaded during a forward pass.
    """

    def init_zo2_upload(self):
        """
        Place modules on their initial devices.

        Embeddings, the final layer norm and the LM head always go to ``self.device``.
        Blocks listed in ``self.offloading_blocks`` are left where they are; every other
        block is moved to ``self.device``. When ``self.offloading_blocks`` is ``None``
        all blocks are treated as offloaded.
        """
        print("Upload head and tail to cuda.")
        self.model.transformer.wte = self.model.transformer.wte.to(self.device)
        self.model.transformer.wpe = self.model.transformer.wpe.to(self.device)
        self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device)
        self.model.lm_head = self.model.lm_head.to(self.device)

        self.num_blocks = len(self.model.transformer.h)
        if self.offloading_blocks is not None:
            # user-supplied selection is kept unchanged
            self.offloading_blocks = self.offloading_blocks
        else:
            # default: every transformer block is offloaded
            self.offloading_blocks = list(range(self.num_blocks))
        print(f"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}")
        for i in range(self.num_blocks):
            if i in self.offloading_blocks:
                continue
            else:
                self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device)
                print(f"Upload block {i} to cuda.")

    @torch.inference_mode()
    def inner_zo_forward(self, idx, pos, targets):
        """
        Dual forward pass through embeddings, blocks, final norm and LM head.

        Every ``task_compute_*`` call takes two input sets and returns two outputs
        (presumably the two perturbed evaluations used for the ZO gradient estimate —
        see MeZO2SGD). Block ``i`` is uploaded after the compute task for block ``i-1``
        is issued, and block ``i-2`` is offloaded before it; whether these overlap
        depends on the ``task_*`` implementations, which are defined in the base class.

        Returns:
            tuple: ``(loss1, loss2)``, the cross-entropy of each of the two passes.
        """
        # token and position embeddings, then their sum, for both input sets
        we1, we2 = self.task_compute_module(self.model.transformer.wte,
                                            inputs1={"input": idx},
                                            inputs2={"input": idx},
                                            grad=self.projected_grad)
        pe1, pe2 = self.task_compute_module(self.model.transformer.wpe,
                                            {"input": pos},
                                            {"input": pos},
                                            self.projected_grad,
                                            compute_sync=False)
        hidden_states1, hidden_states2 = self.task_compute_function(torch.add,
                                                                    {"input": we1, "other": pe1},
                                                                    {"input": we2, "other": pe2},
                                                                    compute_sync=False)
        # bring in the first block before the loop starts
        if 0 in self.offloading_blocks:
            self.model.transformer.h[0] = self.task_upload(
                module=self.model.transformer.h[0],
                device=self.device)
        N = len(self.model.transformer.h)
        for i in range(1, N):
            # step i: offload block i-2, compute block i-1, upload block i
            if i != 1:
                if i-2 in self.offloading_blocks:
                    self.model.transformer.h[i-2] = self.task_offload(
                        module=self.model.transformer.h[i-2],
                        device=self.offloading_device)
            hidden_states1, hidden_states2 = self.task_compute_module(
                self.model.transformer.h[i-1],
                inputs1={"x": hidden_states1},
                inputs2={"x": hidden_states2},
                grad=self.projected_grad)
            if i in self.offloading_blocks:
                self.model.transformer.h[i] = self.task_upload(
                    module=self.model.transformer.h[i],
                    device=self.device)
        # drain: offload the second-to-last block, compute the last one, then offload it
        if N-2 in self.offloading_blocks:
            self.model.transformer.h[N-2] = self.task_offload(
                self.model.transformer.h[N-2], device=self.offloading_device)
        hidden_states1, hidden_states2 = self.task_compute_module(
            self.model.transformer.h[N-1],
            inputs1={"x": hidden_states1},
            inputs2={"x": hidden_states2},
            grad=self.projected_grad
        )
        if N-1 in self.offloading_blocks:
            self.model.transformer.h[N-1] = self.task_offload(
                self.model.transformer.h[N-1], device=self.offloading_device)
        logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f,
                                                    inputs1={"input": hidden_states1},
                                                    inputs2={"input": hidden_states2},
                                                    grad=self.projected_grad,
                                                    weight_decay=0.)  # 'task_compute_module' will remove the first name 'ln_f', so we need to disable weight_decay manually.
        logits1, logits2 = self.task_compute_module(self.model.lm_head,
                                                    inputs1={"input": logits1},
                                                    inputs2={"input": logits2},
                                                    grad=self.projected_grad)
        # flatten (batch, time) so cross_entropy sees one row of logits per target
        loss1, loss2 = self.task_compute_function(F.cross_entropy,
                                                  {"input": logits1.reshape(-1, logits1.size(-1)),
                                                   "target": targets.reshape(-1)},
                                                  {"input": logits2.reshape(-1, logits2.size(-1)),
                                                   "target": targets.reshape(-1)})
        return loss1, loss2

    @torch.inference_mode()
    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):
        """
        Run ``eval_fn(idx, pos, targets)`` with ZO2 eval communication hooks installed on
        the transformer blocks, then remove the hooks and return ``eval_fn``'s output.
        """
        handles = self.add_zo2_eval_comm_hooks(self.model.transformer.h)
        output = eval_fn(idx, pos, targets)
        self.clear_zo2_eval_comm_hooks(handles)
        return output
class LayerNorm(nn.Module):
    """ Layer normalization over the last dimension with a learnable scale and an optional learnable bias. """

    def __init__(self, ndim, bias):
        """Create a scale of ones of size ``ndim``; add a zero bias only when ``bias`` is truthy."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        """Normalize ``input`` over its trailing ``ndim`` features with eps = 1e-5."""
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        """Read ``n_embd``, ``bias`` and ``dropout`` from ``config`` and build the sub-layers."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        use_bias = config.bias
        self.c_fc = nn.Linear(width, hidden, bias=use_bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply c_fc -> gelu -> c_proj -> dropout to ``x``."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
def __init__(self, config):
    """
    Build the GPT stack described by ``config``: token and position embeddings, a dropout
    module, ``n_layer`` blocks, a final LayerNorm and an untied, bias-free LM head; then
    initialise all weights and print the parameter count.
    """
    super().__init__()
    assert config.vocab_size is not None
    assert config.block_size is not None
    self.config = config

    width = config.n_embd
    self.transformer = nn.ModuleDict(dict(
        wte = nn.Embedding(config.vocab_size, width),
        wpe = nn.Embedding(config.block_size, width),
        drop = nn.Dropout(config.dropout),
        h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ln_f = LayerNorm(width, bias=config.bias),
    ))
    self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
    # with weight tying when using torch.compile() some warnings get generated:
    # "UserWarning: functional_call was passed multiple values for tied weights.
    # This behavior is deprecated and will be an error in future versions"
    # not 100% sure what this is, so far seems to be harmless. TODO investigate
    # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    # init all weights
    self.apply(self._init_weights)
    # apply special scaled init to the residual projections, per GPT-2 paper
    for param_name, param in self.named_parameters():
        if not param_name.endswith('c_proj.weight'):
            continue
        torch.nn.init.normal_(param, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    # report number of parameters
    n_millions = self.get_num_params()/1e6
    print("number of parameters: %.2fM" % (n_millions,))
def forward(self, idx, pos, targets=None):
    """
    Compute logits (and optionally the loss) for a batch of token ids.

    Args:
        idx: 2-D tensor of token ids; its second dimension must not exceed ``config.block_size``.
        pos: position indices fed to the position embedding ``wpe``.
        targets: optional target ids; when given, the cross-entropy over all positions is returned.

    Returns:
        tuple: ``(logits, loss)`` where ``loss`` is ``None`` if ``targets`` is ``None``.
    """
    # idx is of shape (B, T)
    _, seq_len = idx.size()
    assert seq_len <= self.config.block_size, f"Cannot forward sequence of length {seq_len}, block size is only {self.config.block_size}"

    # forward the token and position embeddings
    stack = self.transformer
    pos_emb = stack.wpe(pos)
    tok_emb = stack.wte(idx)
    hidden = tok_emb + pos_emb
    # forward the blocks of the transformer
    for layer in stack.h:
        hidden = layer(hidden)
    # forward the final layernorm and the classifier
    logits = self.lm_head(stack.ln_f(hidden))  # (B, T, vocab_size)

    if targets is None:
        return logits, None
    flat_logits = logits.reshape(-1, logits.size(-1))
    loss = F.cross_entropy(flat_logits, targets.reshape(-1))
    return logits, loss
@classmethod
def from_pretrained(cls, model_type, override_args=None):
    """
    Build a ``GPT`` and fill it with the weights of a Hugging Face GPT-2 checkpoint.

    Args:
        model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
        override_args: optional dict; only the key 'dropout' is accepted.

    Returns:
        GPT: a model whose parameters were copied from the checkpoint.
    """
    assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
    if not override_args:
        override_args = {}  # default to empty dict
    # only dropout can be overridden see more notes below
    assert all(k == 'dropout' for k in override_args)
    from transformers import GPT2LMHeadModel
    print("loading weights from pretrained gpt: %s" % model_type)

    # n_layer, n_head and n_embd are determined from model_type
    architectures = {
        'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
        'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
        'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
        'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
    }
    config_args = architectures[model_type]
    print("forcing vocab_size=50257, block_size=1024, bias=True")
    # these three values are always the same for GPT model checkpoints
    config_args.update(vocab_size=50257, block_size=1024, bias=True)
    # we can override the dropout rate, if desired
    if 'dropout' in override_args:
        print(f"overriding dropout rate to {override_args['dropout']}")
        config_args['dropout'] = override_args['dropout']

    # create a from-scratch initialized minGPT model
    model = GPT(GPTConfig(**config_args))
    own_state = model.state_dict()
    # discard the causal mask buffer, it is not a param
    own_keys = [k for k in own_state.keys() if not k.endswith('.attn.bias')]

    # init a huggingface/transformers model
    hf_state = GPT2LMHeadModel.from_pretrained(model_type).state_dict()
    # ignore the mask-related buffers on the HF side as well
    hf_buffer_suffixes = ('.attn.masked_bias', '.attn.bias')
    hf_keys = [k for k in hf_state.keys() if not k.endswith(hf_buffer_suffixes)]

    # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
    # this means that we have to transpose these weights when we import them
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    assert len(hf_keys) == len(own_keys), f"mismatched keys: {len(hf_keys)} != {len(own_keys)}"

    # copy while ensuring all of the parameters are aligned and match in names and shapes
    for key in hf_keys:
        source, target = hf_state[key], own_state[key]
        if key.endswith(transposed):
            # special treatment for the Conv1D weights we need to transpose
            assert source.shape[::-1] == target.shape
            with torch.no_grad():
                target.copy_(source.t())
        else:
            # vanilla copy over the other parameters
            assert source.shape == target.shape
            with torch.no_grad():
                target.copy_(source)

    return model
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.

    Args:
        idx: LongTensor of shape (b, t) holding the prompt token ids.
        max_new_tokens: number of tokens to sample and append.
        temperature: divisor applied to the last-step logits before sampling.
        top_k: if given, only the ``top_k`` most likely tokens can be sampled.

    Returns:
        LongTensor of shape (b, t + max_new_tokens).
    """
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        # forward() requires explicit position indices; calling the model with idx alone
        # raised TypeError, so build positions 0..T-1 for the (possibly cropped) context
        pos = torch.arange(idx_cond.size(1), dtype=torch.long, device=idx_cond.device)
        # forward the model to get the logits for the index in the sequence
        logits, _ = self(idx_cond, pos)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
def create_disk_offload_path(path, module_id):
    """
    Return the file used to offload ``module_id`` to disk, creating its directory.

    The file is ``<path>/<module_id>/tmp.pt``. The containing directory (and ``path``
    itself, if missing) is created so the returned file can be written immediately.
    Previously only ``path`` was ever created and the existence check inside the
    ``isdir`` branch could never fire, leaving ``<path>/<module_id>`` missing.

    Args:
        path: root directory for disk offloading; must not be an existing regular file.
        module_id: name of the per-module sub-directory.

    Returns:
        str: path of the offload file.

    Raises:
        ValueError: if ``path`` is an existing regular file.
    """
    if os.path.isfile(path):
        raise ValueError("'path' must be a dir.")
    module_dir = os.path.join(path, module_id)
    # exist_ok makes repeated calls for the same module harmless
    os.makedirs(module_dir, exist_ok=True)
    return os.path.join(module_dir, 'tmp.pt')
""" self.config = config self.model = model self.lr = config.lr self.weight_decay = config.weight_decay self.zo_eps = config.eps self.max_zo_random_seed = config.max_zo_random_seed self.debug_mode = config.debug_mode defaults = dict( lr=self.lr, weight_decay=self.weight_decay, maximize=False, foreach=None, differentiable=False, fused=None, ) super().__init__(model.parameters(), defaults) @torch.inference_mode def zo_perturb_parameters(self, module: nn.Module, scaling_factor: float=1): """ Applies Gaussian noise to parameters of a module, facilitating zeroth-order optimization. Args: module (nn.Module): Module whose parameters will be perturbed. scaling_factor (float): Scaling factor for the noise applied to the parameters. """ for _, param in module.named_parameters(): if param.requires_grad: # Resample z if self.debug_mode: z = torch.ones_like(param.data) # for debug else: z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype) param.data.add_(scaling_factor * z * self.zo_eps) @torch.inference_mode def zo_update(self, module, weight_decay=None): """ Updates the parameters of a module based on zeroth-order perturbations and optional weight decay. Args: module (nn.Module): Module whose parameters will be updated. weight_decay (float, optional): Weight decay coefficient. If None, it defaults to the configuration. 
@torch.inference_mode()
def zo_forward(self, *args, zo_random_seed: int=None, **kwargs):
    """
    Forward pass that applies zeroth-order perturbations to compute the loss, used for gradient estimation.
    Notice that the application of Gaussian perturbations for the parameters during both the perturbation and update phases should be the same.

    The parameters are perturbed by the first shift, evaluated, moved by the second shift,
    evaluated again, and moved by the third shift before ``zo_update`` is applied. The RNG
    is re-seeded with the same seed before every step so the same noise is regenerated
    each time.

    Args:
        *args, **kwargs: forwarded unchanged to ``inner_zo_forward``.
        zo_random_seed (int, optional): Random seed for reproducibility of perturbations.
            Any integer, including 0, is honoured; ``None`` draws a fresh seed below
            ``max_zo_random_seed``.

    Returns:
        The loss of the first perturbed evaluation.
    """
    self._update_lr()
    # Compare against None explicitly: a truthiness test would silently replace
    # an explicit seed of 0 with a random one.
    if zo_random_seed is not None:
        self.zo_random_seed = zo_random_seed
    else:
        self.zo_random_seed = np.random.randint(self.max_zo_random_seed)
    shifts = self.zo_perturb_shifts()

    torch.manual_seed(self.zo_random_seed)
    self.zo_perturb_parameters(self.model, scaling_factor=shifts[0])
    loss1 = self.inner_zo_forward(*args, **kwargs)

    torch.manual_seed(self.zo_random_seed)
    self.zo_perturb_parameters(self.model, scaling_factor=shifts[1])
    loss2 = self.inner_zo_forward(*args, **kwargs)

    self.projected_grad = self.compute_grad(loss1, loss2)

    torch.manual_seed(self.zo_random_seed)
    self.zo_perturb_parameters(self.model, scaling_factor=shifts[2])

    # the update must regenerate the very same noise, hence the re-seed
    torch.manual_seed(self.zo_random_seed)
    self.zo_update(self.model)
    return loss1
""" tok_emb = self.model.transformer.wte(idx) pos_emb = self.model.transformer.wpe(pos) x = tok_emb + pos_emb for block in self.model.transformer.h: x = block(x) x = self.model.transformer.ln_f(x) x = self.model.lm_head(x) loss = F.cross_entropy( x[:, :-1, :].reshape(-1, x.size(-1)), targets[:, 1:].reshape(-1) ) return loss.detach() @torch.inference_mode() def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): output = eval_fn(idx, pos, targets) return output ================================================ FILE: zo2/optimizer/mezo_sgd/zo2.py ================================================ # Copyright (c) 2025 liangyuwang # Licensed under the Apache License, Version 2.0 import sys sys.path.append('./zo2') import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from collections import deque from .zo import MeZOSGD from ...config.mezo_sgd import MeZOSGDConfig from .utils import * class MeZO2SGD(MeZOSGD): first_call_eval = True # Class variable specifically for tracking eval function """ Extends MeZOSGD to support advanced offloading techniques that enhance the capability to train large models on systems with limited GPU memory. It manages the intricate balance between CPU and GPU, leveraging zeroth-order optimization with dynamic memory management through offloading. """ def __init__(self, model, config: MeZOSGDConfig): """ Initializes the MeZO2SGD optimizer, setting up the necessary configuration for offloading and optimization techniques. Args: model (nn.Module): The model whose parameters will be optimized. config (MeZOSGDConfig): Configuration object specifying optimizer settings including offloading and overlapping options. """ assert config.zo2, "MeZO2SGD can only work with offloading." 
super().__init__(model, config) self.device = config.working_device self.offloading_device = config.offloading_device self.overlap = config.overlap self.offloading_blocks = config.offloading_blocks self.compute_module_optimize_method = config.compute_module_optimize_method self.compute_function_optimize_method = config.compute_function_optimize_method self.communicate_optimize_method = config.communicate_optimize_method self.amp = config.amp self.amp_precision = config.amp_precision self.precision_on_offloading_device = config.precision_on_offloading_device self.precision_on_working_device = config.precision_on_working_device self.amp_compress_method = config.amp_compress_method self.init_zo2() def init_zo2(self): """ Sets up CUDA streams and initializes the offloading and uploading mechanisms required for efficient computation management across devices. """ self.upload_stream = torch.cuda.Stream() self.offload_stream = torch.cuda.Stream() self.compute_stream = torch.cuda.Stream() self.zo_random_seed = None self.rstate = None self.rstate_queue = deque(maxlen=2) self.last_rstate = None self.projected_grad = 0 self.init_zo2_upload() if self.amp: self.init_zo2_amp() def init_zo2_amp(self): """ Initializes the model parameters to use different precision levels based on their current device. This method works with Automatic Mixed Precision (AMP) by setting the precision for parameters based on whether they are located on the working device or the offloading device. 
""" working_device = torch.device(self.device) offloading_device = torch.device(self.offloading_device) for p in self.model.parameters(): if p.device == working_device: p.data = p.data.to(dtype=self.precision_on_working_device) elif p.device == offloading_device: p.data = p.data.to(dtype=self.precision_on_offloading_device) else: raise ValueError(f"Unsupported device found for parameter: {p.device}") def assign_zo2_attributes(self, source, target): """ Utility function to transfer ZO2 specific attributes from one module to another, aiding in maintaining consistency across nested model architectures. Args: source: The source module from which attributes are copied. target: The target module to which attributes are assigned. """ attrs_to_assign = ['upload_stream', 'offload_stream', 'compute_stream', 'zo_random_seed', 'rstate', 'rstate_queue', 'last_rstate', 'projected_grad'] for attr in attrs_to_assign: setattr(target, attr, getattr(source, attr)) @torch.inference_mode def zo_update(self, module, weight_decay=None): """ Applies the computed gradients to update parameters of the module, potentially including a weight decay term. This method is enhanced by managing CUDA state to ensure consistent random number generation across calls. Args: module (nn.Module): The module whose parameters are to be updated. weight_decay (float, optional): Optional weight decay for regularization. """ torch.cuda.set_rng_state(self.last_rstate) super().zo_update(module, weight_decay=weight_decay) self.last_rstate = torch.cuda.get_rng_state() return module @torch.inference_mode() def module_dual_forward(self, module, inputs1, inputs2, projected_grad=0., weight_decay=None): """ Performs two parallel forward computations with perturbed parameters to estimate gradients. This function is key for zeroth-order gradient estimation with support for optional weight decay during parameter update. 
Notice that the application of Gaussian perturbations for the parameters during both the perturbation and update phases should be the same. Args: module (nn.Module): The module on which forward passes are conducted. inputs1 (dict): Inputs for the first forward pass. inputs2 (dict): Inputs for the second forward pass. projected_grad (float): Projected gradient value used for updating parameters. weight_decay (float, optional): Optional weight decay for regularization. """ if projected_grad != 0: module = self.zo_update(module, weight_decay) torch.cuda.set_rng_state(self.rstate) self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[0]) output1 = module(**inputs1) torch.cuda.set_rng_state(self.rstate) self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[1]) output2 = module(**inputs2) torch.cuda.set_rng_state(self.rstate) self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[2]) self.rstate = torch.cuda.get_rng_state() return output1, output2 @torch.inference_mode() def function_dual_forward(self, fn, inputs1, inputs2): """ Executes a provided function twice with dual inputs, supporting the zeroth-order optimization process by enabling the estimation of gradients through function outputs. Args: fn (callable): The function to be executed. inputs1 (dict): Arguments for the first execution of the function. inputs2 (dict): Arguments for the second execution of the function. Returns: tuple: Outputs from the two executions of the function. """ output1 = fn(**inputs1) output2 = fn(**inputs2) return output1, output2 @torch.inference_mode() def zo_forward(self, *args, seed: int=None, **kwargs): """ The overarching forward function that integrates perturbation, gradient estimation, and parameter update within a single coherent process, controlled by the seed for reproducibility. Args: seed (int, optional): Seed for random number generation to ensure reproducibility. 
""" self._update_lr() self.zo_random_seed = seed if seed else np.random.randint(self.max_zo_random_seed) torch.manual_seed(self.zo_random_seed) torch.cuda.manual_seed(self.zo_random_seed) self.rstate = torch.cuda.get_rng_state() self.rstate_queue.append(self.rstate.clone()) if len(self.rstate_queue) == 2: self.last_rstate = self.rstate_queue.popleft() torch.cuda.synchronize() # global sync to make sure all tasks finish loss1, loss2 = self.inner_zo_forward(*args, **kwargs) torch.cuda.synchronize() # global sync to make sure all tasks finish self.projected_grad = self.compute_grad(loss1, loss2) return loss1.detach() #*********************** tasks ***********************# def task_upload(self, module, device='cuda', upload_sync=False, *args, **kwargs): """ Handles the uploading of modules to the GPU, utilizing CUDA streams to potentially overlap computation and communication for efficiency. Args: module (nn.Module): Module to be uploaded. device (str): Target device for the upload. upload_sync (bool): Whether to synchronize the upload stream before proceeding. """ if self.overlap: if upload_sync: self.upload_stream.synchronize() with torch.cuda.stream(self.upload_stream if self.overlap else torch.cuda.current_stream()): module = self.upload_impl( module, device, self.offloading_device, self.communicate_optimize_method, non_blocking=self.overlap, *args, **kwargs ) return module def task_offload(self, module, device='cpu', offload_sync=False, *args, **kwargs): """ Manages the offloading of modules to an alternative storage (e.g., CPU or disk), using CUDA streams to manage dependencies and potentially overlap tasks. Args: module (nn.Module): Module to be offloaded. device (str): Target device for the offload. offload_sync (bool): Whether to synchronize the offload stream before proceeding. 
""" if self.overlap: if offload_sync: self.offload_stream.synchronize() self.compute_stream.synchronize() # offload depends on compute task with torch.cuda.stream(self.offload_stream if self.overlap else torch.cuda.current_stream()): module = self.offload_impl( module, device, self.offloading_device, self.communicate_optimize_method, non_blocking=self.overlap, *args, **kwargs ) return module def task_compute_module(self, module, inputs1, inputs2, grad, compute_sync=False, weight_decay=None, *args, **kwargs): """ Conducts computations on a module with optional dual inputs for gradient estimation, applying synchronization and CUDA streams for efficiency. Args: module (nn.Module): The module on which computations are to be performed. inputs1 (dict): Inputs for the first computation. inputs2 (dict, could be None): Inputs for the second computation, if performing dual forward. grad (float): Gradient value to be applied. compute_sync (bool): Whether to synchronize the compute stream before proceeding. weight_decay (float, optional): Optional weight decay during the update. 
""" if self.overlap: if compute_sync: self.compute_stream.synchronize() self.upload_stream.synchronize() # module compute depends on upload task with torch.cuda.stream(self.compute_stream if self.overlap else torch.cuda.current_stream()): if inputs2 is not None: return self.compute_module_impl( self.module_dual_forward, module, self.compute_module_optimize_method, inputs1=inputs1, inputs2=inputs2, projected_grad=grad, weight_decay=weight_decay, *args, **kwargs ) elif isinstance(inputs1, list): return self.compute_module_impl( None, module, self.compute_module_optimize_method, *inputs1, *args, **kwargs ) elif isinstance(inputs1, dict): return self.compute_module_impl( None, module, self.compute_module_optimize_method, *args, **inputs1, **kwargs ) elif isinstance(inputs1, tuple): return self.compute_module_impl( None, module, self.compute_module_optimize_method, *inputs1[0], *args, **inputs1[1], **kwargs ) else: raise ValueError("Invalid inputs type.") def task_compute_function(self, fn, inputs1, inputs2, compute_sync=False, *args, **kwargs): """ Executes a provided function with dual input sets to facilitate parallel operations and gradient estimation. This method integrates CUDA streams for efficient task execution. Args: fn (callable): The function to execute, typically a PyTorch operation or custom function. inputs1 (dict): Arguments for the first execution of the function. inputs2 (dict, could be None): Arguments for the second execution of the function. compute_sync (bool): Whether to synchronize the compute stream before execution to ensure data readiness. 
""" if self.overlap: if compute_sync: self.compute_stream.synchronize() with torch.cuda.stream(self.compute_stream if self.overlap else torch.cuda.current_stream()): if inputs2 is not None: return self.compute_function_impl( self.function_dual_forward, fn, self.compute_function_optimize_method, inputs1=inputs1, inputs2=inputs2, *args, **kwargs ) elif isinstance(inputs1, list): return self.compute_function_impl( None, fn, self.compute_function_optimize_method, *inputs1, *args, **kwargs ) elif isinstance(inputs1, dict): return self.compute_function_impl( None, fn, self.compute_function_optimize_method, *args, **inputs1, **kwargs ) elif isinstance(inputs1, tuple): return self.compute_function_impl( None, fn, self.compute_function_optimize_method, *inputs1[0], *args, **inputs1[1], **kwargs ) else: raise ValueError("Invalid inputs type.") #*********************** evaluate ***********************# @torch.inference_mode() def zo_eval_forward(self, *args, **kwargs): """ Conducts a model evaluation using the internal forward method without applying any perturbations. This method ensures all tasks finish before and after the evaluation to maintain synchronization. Args: *args, **kwargs: Arguments and keyword arguments for the model's forward method. """ if MeZO2SGD.first_call_eval: print("Warning: ZO2 may not efficiently optimize the evaluation stage, which could result in slower performance.") MeZO2SGD.first_call_eval = False # Disable the warning after the first call torch.cuda.synchronize() # global sync to make sure all tasks finish output = self.inner_zo_eval_forward(*args, **kwargs) torch.cuda.synchronize() # global sync to make sure all tasks finish return output def add_zo2_eval_comm_hooks(self, blocks): """ Attaches communication hooks to model blocks to manage data uploading and offloading during evaluation. This helps in managing memory more efficiently during the eval phase. Args: blocks (list): List of model blocks to attach hooks to. 
Returns: list: A list of hook handles for managing lifecycle. """ handles = [] for block in blocks: if isinstance(block, nn.Module): pre_handle = block.register_forward_pre_hook(self.eval_upload_hook) post_handle = block.register_forward_hook(self.eval_offload_hook) handles.append(pre_handle) handles.append(post_handle) return handles def clear_zo2_eval_comm_hooks(self, handles): """ Removes communication hooks from model blocks after evaluation to clean up and prevent memory leaks. Args: handles (list): List of hook handles to be removed. """ for handle in handles: handle.remove() def eval_upload_hook(self, module, input): """ A forward pre-hook to upload a module to the GPU before its evaluation. Args: module (nn.Module): Module to be uploaded. input: Input data for the module. """ self.upload_impl( module, self.device, self.offloading_device ) return input def eval_offload_hook(self, module, input, output): """ A forward hook to offload a module from the GPU after its evaluation to free up memory. Args: module (nn.Module): Module to be offloaded. input: Input data for the module. output: Output from the module evaluation. """ if self.overlap: with torch.cuda.stream(self.offload_stream): self.offload_impl( module, self.offloading_device, self.offloading_device ) else: self.offload_impl( module, self.offloading_device, self.offloading_device ) return output #*********************** backend ***********************# def upload_impl( self, module: nn.Module, device: str, offloading_device: str, optimize_method: str = "", module_id: str = None, *args, **kwargs ): """ Implements the logic for uploading model components to a specified device. Supports various optimization methods to tailor the upload process for different computing environments. 
""" def _upload_impl(module, device, offloading_device, *args, **kwargs): if offloading_device == "cpu": module = module.to(device, *args, **kwargs) else: if module_id == None: raise ValueError("For disk offloading mode, 'module_id' cannot be None.") offloading_disk_path = get_disk_offload_path(offloading_device, module_id) match type(module): case torch.Tensor: module = torch.load(offloading_disk_path, map_location=device) case nn.Module: module.load_state_dict(torch.load(offloading_disk_path, map_location=device)) case _: raise ValueError clear_disk_offload_path(offloading_device, module_id) return module match optimize_method: case "": module = _upload_impl(module, device, offloading_device, *args, **kwargs) case "bucket": # works on large-scale models bucket = module_to_bucket_inplace(module) bucket = _upload_impl(bucket, device, offloading_device, *args, **kwargs) module = bucket_to_module_inplace(bucket, module) case _: raise NotImplementedError if self.amp: # after uploading, decompress the module to higher precision module = self.amp_decompress_impl(module) return module def offload_impl( self, module: nn.Module, device: str, offloading_device: str, optimize_method: str = "", module_id: str = None, *args, **kwargs ): """ Implements the logic for offloading model components from the GPU to another storage, such as CPU or disk, to manage GPU memory more efficiently. 
""" def _offload_impl(module, device, offloading_device, *args, **kwargs): if offloading_device == "cpu": module = module.to(device, *args, **kwargs) else: if module_id == None: raise ValueError("For disk offloading mode, 'module_id' cannot be None.") offloading_disk_path = create_disk_offload_path(offloading_device, module_id) match type(module): case torch.Tensor: torch.save(module, offloading_disk_path) case nn.Module: torch.save(module.state_dict(), offloading_disk_path) case _: raise ValueError return module if self.amp: # before offloading, compress the module to lower precision module = self.amp_compress_impl(module) match optimize_method: case "": module = _offload_impl(module, device, offloading_device, *args, **kwargs) case "bucket": # works on large-scale models bucket = module_to_bucket_inplace(module) bucket = _offload_impl(bucket, device, offloading_device, *args, **kwargs) module = bucket_to_module_inplace(bucket, module) case _: raise NotImplementedError return module def compute_module_impl( self, forward_fn, module: torch.nn.Module, optimize_method: str, *args, optimize_kwargs = None, **kwargs ): """ Manages the computation tasks on a module, applying various optimization methods to enhance execution speed and efficiency. """ match optimize_method: case "": pass case "torch.compile": # may introduce some precision mismatch module = torch.compile(module, **optimize_kwargs) case _: raise NotImplementedError with torch.autocast(device_type=self.device, dtype=self.amp_precision, enabled=self.amp): if forward_fn is None: return module(*args, **kwargs) else: return forward_fn(module=module, *args, **kwargs) def compute_function_impl( self, function_fn, fn, optimize_method: str, *args, optimize_kwargs = None, **kwargs ): """ Manages the computation tasks on a function, applying various optimization methods to enhance function execution speed and efficiency. 
""" match optimize_method: case "": pass case "torch.jit.script": # may introduce some precision mismatch fn = torch.jit.script(fn, **optimize_kwargs) case _: raise NotImplementedError with torch.autocast(device_type=self.device, dtype=self.amp_precision, enabled=self.amp): if function_fn is None: return fn(*args, **kwargs) else: return function_fn(fn, *args, **kwargs) def amp_decompress_impl(self, module: nn.Module) -> nn.Module: """ Converts the data type of module parameters to a higher precision typically used for computations. This is part of the AMP process where parameters might be temporarily compressed to a lower precision and need to be decompressed back to higher precision for accuracy-critical operations. Args: module (nn.Module): The module whose parameters will be decompressed. Returns: nn.Module: The module with parameters converted to higher precision. """ for p in module.parameters(): match self.amp_compress_method: case "naive": p.data = p.data.to(dtype=self.precision_on_working_device) case _: raise NotImplementedError return module def amp_compress_impl(self, module: nn.Module) -> nn.Module: """ Compresses the data type of module parameters to a lower precision typically used to save memory and improve computational efficiency during less accuracy-critical operations. Args: module (nn.Module): The module whose parameters will be compressed. Returns: nn.Module: The module with parameters converted to lower precision. """ for p in module.parameters(): match self.amp_compress_method: case "naive": p.data = p.data.to(dtype=self.precision_on_offloading_device) case _: raise NotImplementedError return module #*********************** api ***********************# def init_zo2_upload(self): """ Initializes the upload of essential model components to the GPU. This method specifically handles the uploading of model embeddings and head components, and prepares the offloading blocks based on configuration. 
This setup is crucial for managing the active memory footprint during training by selectively uploading and offloading transformer blocks as needed. """ print("Upload head and tail to cuda.") self.model.transformer.wte = self.model.transformer.wte.to(self.device) self.model.transformer.wpe = self.model.transformer.wpe.to(self.device) self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device) self.model.lm_head = self.model.lm_head.to(self.device) self.num_blocks = len(self.model.transformer.h) if self.offloading_blocks is not None: self.offloading_blocks = self.offloading_blocks else: self.offloading_blocks = list(range(self.num_blocks)) print(f"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}") for i in range(self.num_blocks): if i in self.offloading_blocks: continue else: self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device) print(f"Upload block {i} to cuda.") @torch.inference_mode() def inner_zo_forward(self, idx, pos, targets): """ Defines the inner forward logic for zeroth-order optimization, applying perturbations and calculating the loss for gradient estimation. This method, using nanogpt as an example, orchestrates the forward computation across potentially offloaded transformer blocks, ensuring they are uploaded for computation and offloaded post-computation as configured. Args: idx (Tensor): Input indices for token embeddings. pos (Tensor): Position indices for positional embeddings. targets (Tensor): Target outputs for loss calculation. Returns: Tuple[Tensor, Tensor]: The losses computed from two perturbed forward passes, used for gradient estimation. 
""" we1, we2 = self.task_compute_module(self.model.transformer.wte, inputs1={"input": idx}, inputs2={"input": idx}, grad=self.projected_grad) pe1, pe2 = self.task_compute_module(self.model.transformer.wpe, {"input": pos}, {"input": pos}, self.projected_grad) hidden_states1, hidden_states2 = self.task_compute_function(torch.add, {"input": we1, "other": pe1}, {"input": we2, "other": pe2}) if 0 in self.offloading_blocks: self.model.transformer.h[0] = self.task_upload( module=self.model.transformer.h[0], device=self.device) N = len(self.model.transformer.h) for i in range(1, N): if i != 1: if i-2 in self.offloading_blocks: self.model.transformer.h[i-2] = self.task_offload( module=self.model.transformer.h[i-2], device=self.offloading_device) hidden_states1, hidden_states2 = self.task_compute_module( self.model.transformer.h[i-1], inputs1={"x": hidden_states1}, inputs2={"x": hidden_states2}, grad=self.projected_grad) if i in self.offloading_blocks: self.model.transformer.h[i] = self.task_upload( module=self.model.transformer.h[i], device=self.device) if N-2 in self.offloading_blocks: self.model.transformer.h[N-2] = self.task_offload( self.model.transformer.h[N-2], device=self.offloading_device) hidden_states1, hidden_states2 = self.task_compute_module( self.model.transformer.h[N-1], inputs1={"x": hidden_states1}, inputs2={"x": hidden_states2}, grad=self.projected_grad ) if N-1 in self.offloading_blocks: self.model.transformer.h[N-1] = self.task_offload( self.model.transformer.h[N-1], device=self.offloading_device) logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f, inputs1={"input": hidden_states1}, inputs2={"input": hidden_states2}, grad=self.projected_grad, weight_decay=0.) 
logits1, logits2 = self.task_compute_module(self.model.lm_head, inputs1={"input": logits1}, inputs2={"input": logits2}, grad=self.projected_grad) loss1, loss2 = self.task_compute_function(F.cross_entropy, {"input": logits1[:, :-1, :].reshape(-1, logits1.size(-1)), "target": targets[:, 1:].reshape(-1)}, {"input": logits2[:, :-1, :].reshape(-1, logits2.size(-1)), "target": targets[:, 1:].reshape(-1)}) return loss1, loss2 @torch.inference_mode() def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): """ Conducts an evaluation forward pass of the model using the zeroth-order optimization setup, but without applying any perturbations to ensure accurate performance assessment. This function manages the dynamic uploading and offloading of transformer blocks as needed, utilizing pre- and post-hooks to optimize memory usage during evaluation. Args: eval_fn (callable): The evaluation function to be applied, typically involves a forward pass that computes the loss or other metrics without updating model parameters. idx (Tensor): Input indices for token embeddings. pos (Tensor): Position indices for positional embeddings. targets (Tensor): Target outputs for computing the evaluation metric (e.g., loss). Returns: Tensor: The output from the evaluation function, typically loss or accuracy metrics. """ handles = self.add_zo2_eval_comm_hooks(self.model.transformer.h) output = eval_fn(idx, pos, targets) self.clear_zo2_eval_comm_hooks(handles) return output ================================================ FILE: zo2/trainer/__init__.py ================================================ ================================================ FILE: zo2/trainer/hf_transformers/__init__.py ================================================ from .trainer import ZOTrainer ================================================ FILE: zo2/trainer/hf_transformers/trainer.py ================================================ # Copyright 2020-present the HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. """ import contextlib import copy import functools import glob import importlib.metadata import inspect import json import math import os import random import re import shutil import sys import tempfile import time import warnings from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union # Integrations must be imported before ML frameworks: # isort: off from transformers.integrations import ( get_reporting_integration_callbacks, ) # isort: on import huggingface_hub.utils as hf_hub_utils import numpy as np import torch import torch.distributed as dist from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from transformers import Trainer from transformers import __version__ from transformers.configuration_utils import PretrainedConfig from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.hyperparameter_search import 
ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend from transformers.image_processing_utils import BaseImageProcessor from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available from transformers.integrations.tpu import tpu_spmd_dataloader from transformers.modelcard import TrainingSummary from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) from transformers.optimization import Adafactor, get_scheduler from transformers.processing_utils import ProcessorMixin from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_2_3, ) from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_callback import ( CallbackHandler, DefaultFlowCallback, ExportableState, PrinterCallback, ProgressCallback, TrainerCallback, TrainerControl, TrainerState, ) from transformers.trainer_pt_utils import ( DistributedTensorGatherer, EvalLoopContainer, IterableDatasetShard, LabelSmoother, LayerWiseDummyOptimizer, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, find_batch_size, get_model_param_count, get_module_class_from_name, get_parameter_names, nested_concat, nested_detach, nested_numpify, nested_xla_mesh_reduce, reissue_pt_warnings, remove_dummy_checkpoint, set_rng_state_for_device, ) from transformers.trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalLoopOutput, EvalPrediction, HPSearchBackend, HubStrategy, PredictionOutput, RemoveColumnsCollator, SaveStrategy, TrainerMemoryTracker, TrainOutput, check_target_module_exists, default_compute_objective, denumpify_detensorize, enable_full_determinism, find_executable_batch_size, get_last_checkpoint, has_length, neftune_post_forward_hook, number_of_arguments, seed_worker, set_seed, 
speed_metrics, ) from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments from transformers.utils import ( ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, XLA_FSDPV2_MIN_VERSION, PushInProgress, PushToHubMixin, can_return_loss, find_labels, is_accelerate_available, is_apex_available, is_apollo_torch_available, is_bitsandbytes_available, is_datasets_available, is_galore_torch_available, is_grokadamw_available, is_in_notebook, is_ipex_available, is_liger_kernel_available, is_lomo_available, is_peft_available, is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_schedulefree_available, is_torch_compile_available, is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, is_torch_neuroncore_available, is_torch_npu_available, is_torch_xla_available, is_torch_xpu_available, is_torchao_available, logging, strtobool, ) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.quantization_config import QuantizationMethod DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback if is_in_notebook(): from transformers.utils.notebook import NotebookProgressCallback DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback if is_apex_available(): from apex import amp if is_datasets_available(): import datasets if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met from torch_xla import __version__ as XLA_VERSION IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) if IS_XLA_FSDPV2_POST_2_2: import torch_xla.distributed.spmd as xs import torch_xla.runtime as xr else: IS_XLA_FSDPV2_POST_2_2 = False if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as 
def _is_peft_model(model):
    """Return True when `model` is an instance of a PEFT model class; False if peft is unavailable."""
    if not is_peft_available():
        return False
    peft_classes = [PeftModel]
    # `PeftMixedModel` was introduced in peft>=0.7.0 and has to be checked as well:
    # https://github.com/huggingface/transformers/pull/28321
    installed_peft = version.parse(importlib.metadata.version("peft"))
    if installed_peft >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        peft_classes.append(PeftMixedModel)
    return isinstance(model, tuple(peft_classes))


def _get_fsdp_ckpt_kwargs():
    """Return `{"adapter_only": True}` when accelerate's `save_fsdp_model` accepts that argument, else `{}`."""
    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
    if not is_accelerate_available():
        return {}
    accepted_params = inspect.signature(save_fsdp_model).parameters
    return {"adapter_only": True} if "adapter_only" in accepted_params else {}
def safe_globals():
    """
    Return a context manager that allowlists the numpy objects needed by `torch.load`.

    For torch releases older than 2.6 a no-op `contextlib.nullcontext()` is returned.
    """
    # Starting from version 2.4 PyTorch introduces a check for the objects loaded
    # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
    # a default and requires allowlisting of objects being loaded.
    # See: https://github.com/pytorch/pytorch/pull/137602
    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    # See: https://github.com/huggingface/accelerate/pull/3036
    torch_release = version.parse(torch.__version__).release
    if torch_release < version.parse("2.6").release:
        return contextlib.nullcontext()

    if version.parse(np.__version__) >= version.parse("2.0.0"):
        np_core = np._core
    else:
        np_core = np.core
    allowlist = [
        np_core.multiarray._reconstruct,
        np.ndarray,
        np.dtype,
        # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
        # all versions of numpy
        type(np.dtype(np.uint32)),
    ]
    return torch.serialization.safe_globals(allowlist)
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module, None] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    compute_loss_func: Optional[Callable] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
    optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
    """Build the trainer and switch on ZO2 mode when the model supports it.

    All parameters are forwarded unchanged to `Trainer.__init__`. ZO mode is
    enabled when `model` exposes a `zo_training` attribute; in that case the
    training arguments are validated with `_zo2_unsupported_conditions`.

    Raises:
        NotImplementedError: if ZO mode is enabled together with a
            configuration rejected by `_zo2_unsupported_conditions`.
    """
    # Forward by keyword so that a reordering of the parent signature cannot
    # silently bind a value to the wrong parameter.
    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_loss_func=compute_loss_func,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # ZO2 added: ZO mode is detected from the model itself.
    self.zo = hasattr(model, "zo_training")

    # ZO2 added: hook buffers. Created unconditionally so that the
    # `register_zo2_training_step_*_hook` methods never fail with an
    # AttributeError; the hooks are only invoked by `zo2_training_step`.
    self.zo2_training_step_pre_hooks = []
    self.zo2_training_step_post_hooks = []

    if self.zo:
        print("ZO training mode is enabled.")
        # ZO2 added: currently unsupported conditions. Validate `self.args`
        # (set by the parent constructor) instead of the raw `args`
        # parameter, which is allowed to be None.
        self._zo2_unsupported_conditions(self.args)
release_memory (self.model_wrapped,) = release_memory(self.model_wrapped) self.model_wrapped = self.model # Check for DeepSpeed *after* the initial pass and modify the config if self.is_deepspeed_enabled: # Temporarily unset `self.args.train_batch_size` original_bs = self.args.per_device_train_batch_size self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) self.propagate_args_to_deepspeed(True) self.args.per_device_train_batch_size = original_bs self.state.train_batch_size = self._train_batch_size logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() if self.is_fsdp_xla_v2_enabled: train_dataloader = tpu_spmd_dataloader(train_dataloader) # Setting up training control variables: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size ( num_train_epochs, num_update_steps_per_epoch, num_examples, num_train_samples, epoch_based, len_dataloader, max_steps, ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) num_train_tokens = None if self.args.include_tokens_per_second: num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps) # If going by epochs, multiply tokens linearly if len_dataloader is not None and epoch_based: num_train_tokens *= args.num_train_epochs # Otherwise since its steps, we just multiply by grad accum else: num_train_tokens *= args.gradient_accumulation_steps if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module # references registered here no longer work on other gpus, breaking the module raise ValueError( "Currently --debug 
underflow_overflow is not supported under DP. Please use DDP" " (torchrun or torch.distributed.launch (deprecated))." ) else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) if is_fsdp2: delay_optimizer_creation = False # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: self.lr_scheduler = None self._created_lr_scheduler = False if self.is_deepspeed_enabled: self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState( stateful_callbacks=[ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) ] ) self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size # Compute absolute values for logging, eval, and save if given as ratio self.state.compute_steps(args, max_steps) # Activate gradient checkpointing if needed if args.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) # ZO2 added -> # model = self._wrap_model(self.model_wrapped) model = self.model # as the model is wrapped, don't use `accelerator.prepare` # this is for unhandled cases such as # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False # ZO2 added -> use_accelerator_prepare = False if use_accelerator_prepare and self.is_fsdp_enabled: # In case of auto_find_batch_size=True # 
Remove FSDP wrapping from sub-models. self.model = unwrap_model(self.model, recursive=True) if delay_optimizer_creation: if use_accelerator_prepare: # configure fsdp plugin for qlora if any self._fsdp_qlora_plugin_updates() if self.accelerator.mixed_precision != "fp8": self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare if use_accelerator_prepare: self.model.train() if hasattr(self.lr_scheduler, "step"): if self.use_apex: model = self.accelerator.prepare(self.model) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: # In this case we are in DDP + LOMO, which should be supported self.optimizer = self.accelerator.prepare(self.optimizer) if self.is_fsdp_enabled: self.model = self.model_wrapped = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # ZO2 added -> if delay_optimizer_creation: if self.zo: self.create_optimizer_and_scheduler(num_training_steps=max_steps, model=model) else: self.create_optimizer_and_scheduler(num_training_steps=max_steps) # ZO2 added -> # Check if saved optimizer or scheduler states exist if self.zo: _, model = self._load_optimizer_and_scheduler(resume_from_checkpoint, model) else: self._load_optimizer_and_scheduler(resume_from_checkpoint) # # ckpt loading # if resume_from_checkpoint is not None: # if self.is_deepspeed_enabled: # deepspeed_load_checkpoint( # self.model_wrapped, resume_from_checkpoint, load_module_strict=not 
_is_peft_model(self.model) # ) # elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: # self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) # # Check if saved optimizer or scheduler states exist # self._load_optimizer_and_scheduler(resume_from_checkpoint) # self._load_scaler(resume_from_checkpoint) # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") if self.args.per_device_train_batch_size != self._train_batch_size: logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps:,}") logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() epochs_trained = 0 steps_trained_in_current_epoch = 0 steps_trained_progress_bar = None # Check if continuing training from a checkpoint if resume_from_checkpoint is not None and os.path.isfile( os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) self.compare_trainer_and_checkpoint_args(self.args, self.state) self._load_callback_state() epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps else: steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." 
) # Update the references for attr in ("model", "optimizer", "lr_scheduler"): setattr(self.callback_handler, attr, getattr(self, attr)) self.callback_handler.train_dataloader = train_dataloader self.state.init_training_references(self, max_steps, num_train_epochs, trial) # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss = torch.tensor(0.0, device=args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() grad_norm: Optional[float] = None learning_rate = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) if args.eval_on_start: self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) for epoch in range(epochs_trained, num_train_epochs): epoch_dataloader = train_dataloader if hasattr(epoch_dataloader, "set_epoch"): epoch_dataloader.set_epoch(epoch) # Reset the past mems state at the beginning of each epoch if necessary. 
if args.past_index >= 0: self._past = None steps_in_epoch = ( len(epoch_dataloader) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False steps_skipped = 0 if steps_trained_in_current_epoch > 0: epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 rng_to_sync = True step = -1 epoch_iterator = iter(epoch_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches remainder = num_examples % args.gradient_accumulation_steps if remainder == 0: remainder = args.gradient_accumulation_steps update_step = -1 total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 if args.gradient_accumulation_steps == 1: total_updates -= 1 for _ in range(total_updates): update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) for i, inputs in enumerate(batch_samples): step += 1 do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch # Since we perform prefetching, we need to manually set sync_gradients self.accelerator.gradient_state._set_sync_gradients(do_sync_step) if self.args.include_num_input_tokens_seen: main_input_name = getattr(self.model, "main_input_name", "input_ids") if main_input_name not in inputs: logger.warning( "Tried to track the number of tokens seen, however the current model is " "not configured properly to know what item is the input. To fix this, add " "a `main_input_name` attribute to the model class you are using." 
) else: input_tokens = inputs[main_input_name].numel() input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 if steps_trained_progress_bar is not None: steps_trained_progress_bar.update(1) if steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) continue elif steps_trained_progress_bar is not None: steps_trained_progress_bar.close() steps_trained_progress_bar = None if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) # ZO2 added -> estimate gradient and updates if self.zo: tr_loss_step = self.zo2_training_step(model, inputs) else: # We explicitly want to avoid relying on `accelerator.accumulate` for generation training context = ( functools.partial(self.accelerator.no_sync, model=model) if i != len(batch_samples) - 1 and self.accelerator.distributed_type != DistributedType.DEEPSPEED else contextlib.nullcontext ) with context(): tr_loss_step = self.training_step(model, inputs, num_items_in_batch) if ( args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): # if loss is nan or inf simply add the average of previous logged losses tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: if tr_loss.device != tr_loss_step.device: raise ValueError( f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" ) tr_loss = tr_loss + tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) if do_sync_step: # ZO2 added -> ignore parameter update since it is fuesd with 
model forward if self.zo: pass else: # Since we perform prefetching, we need to manually set sync_gradients to True self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: if is_sagemaker_mp_enabled() and args.fp16: _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision _grad_norm = nn.utils.clip_grad_norm_( amp.master_params(self.optimizer), args.max_grad_norm, ) else: _grad_norm = self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) if ( is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED ): grad_norm = model.get_global_grad_norm() # In some cases the grad norm may not return a float if hasattr(grad_norm, "item"): grad_norm = grad_norm.item() else: grad_norm = _grad_norm self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) self.optimizer.step() self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) # get leaning rate before update learning_rate = self._get_learning_rate() if not self.accelerator.optimizer_step_was_skipped: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate( tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate, ) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. 
Since we are breaking the loop early, we need to manually # insert the mark_step here. if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): xm.mark_step() break # We also need to break out of the nested loop if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): xm.mark_step() break if step < 0: logger.warning( "There seems not to be a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate( tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate ) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: if is_torch_xla_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) else: logger.warning( "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) if self.control.should_training_stop: break if args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sure the model has been saved by process 0. 
if is_torch_xla_available(): xm.rendezvous("load_best_model_at_end") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() self._load_best_model() # add remaining tr_loss self._total_loss_scalar += tr_loss.item() effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError train_loss = self._total_loss_scalar / effective_global_step metrics = speed_metrics( "train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps, num_tokens=num_train_tokens, ) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss self.is_in_train = False self._memory_tracker.stop_and_update_metrics(metrics) self.log(metrics) run_dir = self._get_output_dir(trial) checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: for checkpoint in checkpoints_sorted: if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint, ignore_errors=True) self.control = self.callback_handler.on_train_end(args, self.state, self.control) # Wait for the checkpoint to be uploaded. self._finish_current_push() # After training we make sure to retrieve back the original forward pass method # for the embedding layer by removing the forward post hook. if self.neftune_noise_alpha is not None: self._deactivate_neftune(self.model) return TrainOutput(self.state.global_step, train_loss, metrics) def _load_optimizer_and_scheduler(self, checkpoint, model=None): """ disable the optimizer resume. 
""" output = super()._load_optimizer_and_scheduler(checkpoint) if self.zo and model is not None: model.opt = self.optimizer return output, model return output def create_optimizer_and_scheduler(self, num_training_steps: int, model: nn.Module=None): """ disable the optimizer but leave the learning rate scheduler. """ if not self.zo: self.create_optimizer() if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer optimizer = self.optimizer.optimizer else: optimizer = self.optimizer else: if model is None: optimizer = self.optimizer = self.model.opt else: optimizer = self.optimizer = model.opt self.create_scheduler(num_training_steps, optimizer) def _move_model_to_device(self, model, device): pass #*********************** zo2 functions ***********************# def _zo2_unsupported_conditions(self, args): if args.gradient_accumulation_steps > 1: raise NotImplementedError if args.n_gpu > 1: raise NotImplementedError("Currently ZO2 only support one working device") if args.deepspeed: raise NotImplementedError if is_sagemaker_mp_enabled(): raise NotImplementedError if args.torch_compile: raise NotImplementedError def register_zo2_training_step_pre_hook(self, hook_fn): """ example: def print_zo_info(model, inputs): tqdm.write("projected grad: {}".format(model.opt.projected_grad)) return model, inputs trainer = ZOTrainer(...) 
def zo2_training_step(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """Run one ZO2 training step and return the loss produced by the model.

    Registered pre hooks are applied in order and may replace `model` and
    `inputs`. The model is then put in ZO training mode, called with the
    prepared inputs, and put back in ZO evaluation mode. Registered post
    hooks are applied in order and may replace `model`, the prepared inputs
    and the loss.

    Args:
        model: module exposing `zo_train()` and `zo_eval()`; calling it with
            the inputs as keyword arguments yields the loss.
        inputs: batch passed to the pre hooks and to `_prepare_inputs`.

    Returns:
        The loss, after all post hooks have run.
    """
    for pre_hook in self.zo2_training_step_pre_hooks:
        model, inputs = pre_hook(model, inputs)

    model.zo_train()
    prepared = self._prepare_inputs(inputs)
    step_loss = model(**prepared)
    model.zo_eval()

    for post_hook in self.zo2_training_step_post_hooks:
        model, prepared, step_loss = post_hook(model, prepared, step_loss)
    return step_loss
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from trl import SFTTrainer import contextlib import functools import glob import inspect import math import os import random import re import shutil import sys import time import warnings from collections.abc import Mapping from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import copy import numpy as np from functools import wraps from tqdm.auto import tqdm from transformers import Trainer, DataCollator from sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV # Integrations must be imported before ML frameworks: from transformers.integrations import ( # isort: split default_hp_search_backend, get_reporting_integration_callbacks, hp_params, is_fairscale_available, is_optuna_available, is_ray_tune_available, is_sigopt_available, is_wandb_available, run_hp_search_optuna, run_hp_search_ray, run_hp_search_sigopt, run_hp_search_wandb, ) import numpy as np import torch import torch.distributed as dist from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from huggingface_hub import Repository from transformers import __version__ from transformers.configuration_utils import PretrainedConfig from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled 
from transformers.dependency_versions_check import dep_version_check from transformers.modelcard import TrainingSummary from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES from transformers.optimization import Adafactor, get_scheduler from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_callback import ( CallbackHandler, DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerCallback, TrainerControl, TrainerState, ) from transformers.trainer_pt_utils import ( DistributedLengthGroupedSampler, DistributedSamplerWithLoop, DistributedTensorGatherer, IterableDatasetShard, LabelSmoother, LengthGroupedSampler, SequentialDistributedSampler, ShardSampler, distributed_broadcast_scalars, distributed_concat, find_batch_size, get_module_class_from_name, get_parameter_names, nested_concat, nested_detach, nested_numpify, nested_truncate, nested_xla_mesh_reduce, reissue_pt_warnings, ) from transformers.trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalLoopOutput, EvalPrediction, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, PredictionOutput, RemoveColumnsCollator, ShardedDDPOption, TrainerMemoryTracker, TrainOutput, default_compute_objective, default_hp_space, denumpify_detensorize, enable_full_determinism, find_executable_batch_size, get_last_checkpoint, has_length, number_of_arguments, seed_worker, set_seed, speed_metrics, ) from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments from transformers.utils import ( CONFIG_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, find_labels, get_full_repo_name, is_apex_available, is_datasets_available, is_in_notebook, is_ipex_available, is_sagemaker_dp_enabled, 
is_sagemaker_mp_enabled, is_torch_tensorrt_fx_available, is_torch_tpu_available, is_accelerate_available, is_torchdynamo_available, logging, ) from transformers.utils.generic import ContextManagers from transformers.trainer_pt_utils import ( _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state, get_model_param_count, ) _is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback if is_in_notebook(): from transformers.utils.notebook import NotebookProgressCallback DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback if is_apex_available(): from apex import amp if is_datasets_available(): import datasets if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl if is_fairscale_available(): dep_version_check("fairscale") import fairscale from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP from fairscale.nn.wrap import auto_wrap from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat else: IS_SAGEMAKER_MP_POST_1_10 = False skip_first_batches = None if is_accelerate_available(): from accelerate import __version__ as accelerate_version if version.parse(accelerate_version) >= version.parse("0.16"): from accelerate import skip_first_batches if TYPE_CHECKING: import optuna logger = logging.get_logger(__name__) # Name of the files used for checkpointing TRAINING_ARGS_NAME = "training_args.bin" 
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class ZOSFTTrainer(SFTTrainer):
    """`trl.SFTTrainer` variant that can train ZO2 (zeroth-order) models.

    A model is treated as a ZO2 model when it exposes a ``zo_training``
    attribute. In that mode the training step is a single forward call made
    between ``model.zo_train()`` and ``model.zo_eval()``, the optimizer is
    taken from ``model.opt``, and the regular clip/step/zero-grad sequence of
    the Hugging Face loop is skipped. For any other model the original
    first-order code path is used unchanged.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str],
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = False,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: int = 1000,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
    ):
        """Detect ZO2 mode, validate the configuration and defer to `SFTTrainer`.

        All parameters are forwarded positionally and unchanged to
        ``SFTTrainer.__init__``.

        Raises:
            NotImplementedError: if a ZO2 model is combined with a training
                configuration rejected by `_zo2_unsupported_conditions`.
        """
        # ZO2 added: a model exposing `zo_training` switches the trainer to ZO mode.
        self.zo = hasattr(model, "zo_training")
        if self.zo:
            print("ZO training mode is enabled.")

        # ZO2 added: fail fast on unsupported settings. `args` defaults to None,
        # so it can only be inspected here when the caller actually supplied it.
        if self.zo and args is not None:
            self._zo2_unsupported_conditions(args)

        # ZO2 added: hook buffers. They are created unconditionally so that the
        # public register_* methods never hit a missing attribute.
        self.zo2_training_step_pre_hooks = []
        self.zo2_training_step_post_hooks = []

        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init,
                         compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, peft_config,
                         dataset_text_field, packing, formatting_func, max_seq_length, infinite,
                         num_of_sequences, chars_per_token, dataset_num_proc, dataset_batch_size,
                         neftune_noise_alpha, model_init_kwargs)

        # No args were supplied: validate whatever arguments the parent
        # constructor ended up with (if it populated `self.args`).
        if self.zo and args is None:
            resolved_args = getattr(self, "args", None)
            if resolved_args is not None:
                self._zo2_unsupported_conditions(resolved_args)

    def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
        """
        We overload the original training loop to add ZO2. Search key word "ZO2 added" for those updates.
        """
        self._train_batch_size = batch_size
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )
        if args.deepspeed:
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        elif not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        # ZO2 added -> the model is used as-is instead of being wrapped:
        # model = self._wrap_model(self.model_wrapped)
        model = self.model

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # ZO2 added -> in ZO mode the optimizer is read from the model itself
        if delay_optimizer_creation:
            if self.zo:
                self.create_optimizer_and_scheduler(num_training_steps=max_steps, model=model)
            else:
                self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # ZO2 added ->
        # Check if saved optimizer or scheduler states exist
        if self.zo:
            _, model = self._load_optimizer_and_scheduler(resume_from_checkpoint, model)
        else:
            self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {num_examples:,}")
        logger.info(f" Num Epochs = {num_train_epochs:,}")
        logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {max_steps:,}")
        logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f" Continuing training from epoch {epochs_trained}")
            logger.info(f" Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                if skip_first_batches is None:
                    logger.info(
                        f" Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f" Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if is_torch_less_than_1_11 or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                # ZO2 added -> estimate gradient and updates
                if self.zo:
                    tr_loss_step = self.zo2_training_step(model, inputs)
                else:
                    if (
                        (total_batched_samples % args.gradient_accumulation_steps != 0)
                        and args.local_rank != -1
                        and args._no_sync_in_gradient_accumulation
                    ):
                        # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                        with model.no_sync():
                            tr_loss_step = self.training_step(model, inputs)
                    else:
                        tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # ZO2 added -> ignore parameter update since it is fused with model forward.
                    # NOTE(review): the clip / optimizer-step / scheduler-step / zero_grad
                    # sequence below is taken to be entirely first-order-only, because
                    # `optimizer_was_run` is assigned only on this path — confirm the extent
                    # of this branch against the upstream file.
                    if not self.zo:
                        # Gradient clipping
                        if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                            # deepspeed does its own clipping

                            if self.do_grad_scaling:
                                # Reduce gradients first for XLA
                                if is_torch_tpu_available():
                                    gradients = xm._fetch_gradients(self.optimizer)
                                    xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                                # AMP: gradients need unscaling
                                self.scaler.unscale_(self.optimizer)

                            if is_sagemaker_mp_enabled() and args.fp16:
                                self.optimizer.clip_master_grads(args.max_grad_norm)
                            elif hasattr(self.optimizer, "clip_grad_norm"):
                                # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                                self.optimizer.clip_grad_norm(args.max_grad_norm)
                            elif hasattr(model, "clip_grad_norm_"):
                                # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                                model.clip_grad_norm_(args.max_grad_norm)
                            else:
                                # Revert to normal clipping otherwise, handling Apex or full precision
                                nn.utils.clip_grad_norm_(
                                    amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                    args.max_grad_norm,
                                )

                        # Optimizer step
                        optimizer_was_run = True
                        if self.deepspeed:
                            pass  # called outside the loop
                        elif is_torch_tpu_available():
                            if self.do_grad_scaling:
                                self.scaler.step(self.optimizer)
                                self.scaler.update()
                            else:
                                xm.optimizer_step(self.optimizer)
                        elif self.do_grad_scaling:
                            scale_before = self.scaler.get_scale()
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                            scale_after = self.scaler.get_scale()
                            optimizer_was_run = scale_before <= scale_after
                        else:
                            self.optimizer.step()

                        if optimizer_was_run and not self.deepspeed:
                            self.lr_scheduler.step()

                        model.zero_grad()

                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.local_rank != -1:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        # global_step stays 0 when the iterator produced no batch at all; clamp the
        # divisor so the average does not raise ZeroDivisionError in that case.
        effective_global_step = max(self.state.global_step, 0.001)
        train_loss = self._total_loss_scalar / effective_global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)

    @wraps(Trainer.train)
    def train(self, *args, **kwargs):
        """
        ZO2 does not support neftune.

        Calls `transformers.Trainer.train` directly, so no neftune-specific
        setup or teardown is performed around the training run by this method.
        """
        return Trainer.train(self, *args, **kwargs)

    def _load_optimizer_and_scheduler(self, checkpoint, model=None):
        """Load optimizer/scheduler state through the parent and re-link the ZO optimizer.

        Args:
            checkpoint: forwarded unchanged to the parent implementation.
            model: optional ZO2 model; when given in ZO mode, its ``opt``
                attribute is rebound to ``self.optimizer`` after loading.

        Returns:
            ``(parent_result, model)`` in ZO mode when ``model`` is given,
            otherwise just the parent's return value.
        """
        output = super()._load_optimizer_and_scheduler(checkpoint)
        if self.zo and model is not None:
            # keep the model and the trainer pointing at the same optimizer object
            model.opt = self.optimizer
            return output, model
        return output

    def create_optimizer_and_scheduler(self, num_training_steps: int, model: nn.Module = None):
        """Create the optimizer and the learning-rate scheduler.

        In ZO mode no optimizer is built: the one already attached to the model
        (``model.opt``, or ``self.model.opt`` when ``model`` is None) is adopted
        as ``self.optimizer``. The scheduler is created in both modes.

        Args:
            num_training_steps: total number of training steps, passed to
                ``create_scheduler``.
            model: optional model whose ``opt`` attribute is used in ZO mode.
        """
        if self.zo:
            opt_owner = self.model if model is None else model
            optimizer = self.optimizer = opt_owner.opt
        else:
            self.create_optimizer()
            if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
                # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
                optimizer = self.optimizer.optimizer
            else:
                optimizer = self.optimizer
        self.create_scheduler(num_training_steps, optimizer)

    def _move_model_to_device(self, model, device):
        """Do nothing: this trainer never moves the model.

        NOTE(review): presumably device placement is left to the caller or to
        the ZO2 model itself — confirm.
        """
        pass

    #*********************** zo2 functions ***********************#

    def _zo2_unsupported_conditions(self, args):
        """Reject training configurations that ZO2 cannot run.

        Args:
            args: training arguments to validate.

        Raises:
            NotImplementedError: on gradient accumulation, more than one GPU,
                DeepSpeed, TPU, an installed fairscale, SageMaker model
                parallelism, or ``torch_compile``.
        """
        if args.gradient_accumulation_steps > 1:
            raise NotImplementedError("ZO2 does not support gradient accumulation (gradient_accumulation_steps > 1)")
        if args.n_gpu > 1:
            raise NotImplementedError("Currently ZO2 only support one working device")
        if args.deepspeed:
            raise NotImplementedError("ZO2 does not support DeepSpeed")
        if is_torch_tpu_available(check_device=False):
            raise NotImplementedError("ZO2 does not support TPU / torch_xla environments")
        if is_fairscale_available():
            # the check is on availability of the package, not on whether it is used
            raise NotImplementedError("ZO2 does not support environments where fairscale is installed")
        if is_sagemaker_mp_enabled():
            raise NotImplementedError("ZO2 does not support SageMaker model parallelism")
        if args.torch_compile:
            raise NotImplementedError("ZO2 does not support torch_compile")

    def register_zo2_training_step_pre_hook(self, hook_fn):
        """Register ``hook_fn(model, inputs) -> (model, inputs)`` to run before each ZO2 step.

        Hooks run in registration order.
        """
        self.zo2_training_step_pre_hooks.append(hook_fn)

    def register_zo2_training_step_post_hook(self, hook_fn):
        """Register ``hook_fn(model, inputs, loss) -> (model, inputs, loss)`` to run after each ZO2 step.

        Hooks run in registration order; the loss returned by the last hook is
        what the training step returns.
        """
        self.zo2_training_step_post_hooks.append(hook_fn)

    def zo2_training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one ZO2 training step.

        Pre-hooks are applied first, then the model is called once between
        ``model.zo_train()`` and ``model.zo_eval()``, and finally the post-hooks
        are applied.

        Args:
            model: ZO2 model providing ``zo_train()``, ``zo_eval()`` and a
                forward call that returns the loss.
            inputs: batch from the dataloader; passed through
                ``self._prepare_inputs`` before the forward call.

        Returns:
            The loss returned by the model, possibly replaced by post-hooks.
        """
        for pre_hook_fn in self.zo2_training_step_pre_hooks:
            model, inputs = pre_hook_fn(model, inputs)
        model.zo_train()
        inputs = self._prepare_inputs(inputs)
        loss = model(**inputs)
        model.zo_eval()
        for post_hook_fn in self.zo2_training_step_post_hooks:
            model, inputs, loss = post_hook_fn(model, inputs, loss)
        return loss
# Copyright (c) 2025 liangyuwang
# Licensed under the Apache License, Version 2.0

import os
import random

import numpy as np
import torch
from torch import nn


def print_all(module: nn.Module, inputs, outputs):
    """Print value ranges of a module's parameters, inputs and outputs.

    The signature matches a PyTorch forward hook, so ``inputs`` may be a
    single tensor, a mapping of tensors (e.g. tokenizer output), or the
    tuple/list of positional arguments that forward hooks receive.
    Non-tensor entries are skipped.

    Args:
        module: module whose parameters are summarised (min, max, mean).
        inputs: tensor, mapping or tuple/list of forward inputs.
        outputs: forward output; only a plain tensor is summarised.
    """
    print("Param: ")
    for p in module.parameters():
        print(p.min().item(), p.max().item(), p.mean().item())

    print("Inputs: ")
    if isinstance(inputs, torch.Tensor):
        candidates = (inputs,)
    elif hasattr(inputs, "items"):
        candidates = tuple(value for _, value in inputs.items())
    elif isinstance(inputs, (tuple, list)):
        # forward hooks pass positional inputs as a tuple
        candidates = tuple(inputs)
    else:
        candidates = ()
    for value in candidates:
        if isinstance(value, torch.Tensor):
            print(value.min().item(), value.max().item())

    print("Output: ")
    if isinstance(outputs, torch.Tensor):
        print(outputs.min().item(), outputs.max().item(), outputs.mean().item())
    else:
        print("Unrecongized outputs.")
    print("*" * 20)


def print_hook(module, input, output):
    """Print the weight range of ``module`` and the min/max/mean of ``output``.

    Reads ``module.weight`` and calls tensor reductions on ``output``, so the
    module must have a ``weight`` tensor and ``output`` must be a tensor. The
    ``input`` parameter is unused; its name is kept for keyword compatibility.
    """
    print(module, f"{module.weight.min().item():.4f}, {module.weight.max().item():.4f}")
    print(f"{output.min().item():.8f} {output.max().item():.8f} {output.mean().item():.8f}")


def print_para_and_device(model):
    """Print ``<parameter name>: <device>`` for every named parameter of ``model``."""
    for name, param in model.named_parameters():
        print(f"{name}: {param.device}")


def cal_self_reg_loss(logits, labels):
    """Return the next-token cross-entropy loss.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``: the last logit step and the first label are dropped, and both
    are flattened before ``nn.CrossEntropyLoss`` (default reduction) is applied.

    Args:
        logits: indexed as ``[:, :-1, :]``, i.e. a 3-D tensor whose last
            dimension is the class dimension.
        labels: indexed as ``[:, 1:]``, i.e. a 2-D tensor of class indices.
    """
    shifted_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shifted_labels = labels[:, 1:].reshape(-1)
    return nn.CrossEntropyLoss()(shifted_logits, shifted_labels)


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` to the environment, sets
    ``cudnn.deterministic = True`` and disables ``cudnn.benchmark``.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False