[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# ZO2 (Zeroth-Order Offloading): Full Parameter Fine-Tuning 175B LLMs with 18GB GPU Memory\n\n[![arXiv](https://img.shields.io/badge/Arxiv-2503.12668-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2503.12668)\n[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/liangyuwang/zo2/blob/main/LICENSE)\n[![GitDiagran](https://img.shields.io/badge/Git-Diagram%20-blue)](https://gitdiagram.com/liangyuwang/zo2)\n[![DeepWiki](https://img.shields.io/badge/Devin-DeepWiki%20-green)](https://deepwiki.com/liangyuwang/zo2)\n<!-- <a target=\"_blank\" href=\"https://colab.research.google.com/github/liangyuwang/zo2/blob/main/tutorial/colab.ipynb\">\n  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n</a> -->\n\n👋 Welcome! **ZO2** is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using **zeroth-order (ZO)** optimization techniques and advanced **offloading** technologies. This framework is particularly tailored for setups with limited GPU memory (e.g. 
fine-tune **[OPT-175B](https://arxiv.org/abs/2205.01068)** with just **18GB GPU memory**), enabling the fine-tuning of models that were previously unmanageable due to hardware constraints.\n\n- The table below displays the GPU memory usage for various OPT model sizes when fine-tuned using the ZO2 framework:\n\n|        OPT Models        |   1.3B   |   2.7B   |   6.7B   |   13B   |   30B   |    66B    |        175B        |\n| :-----------------------: | :------: | :------: | :------: | :------: | :------: | :-------: | :-----------------: |\n| **GPU memory (GB)** | `3.75` | `4.14` | `4.99` | `6.18` | `8.86` | `12.07` | **`18.04`** |\n\n- [Install](#️installation) the package and execute the following test to see the memory usage:\n\n```shell\n  bash test/mezo_sgd/hf_opt/record_zo2_memory.sh\n```\n\n## 📰 News\n\n- 16/07/2025: ZO2 has been accepted to [COLM](https://colmweb.org/index.html).\n- 02/05/2025: Added support for [Qwen3](https://qwenlm.github.io/blog/qwen3/). You can now fully fine-tune the [32B version](https://huggingface.co/Qwen/Qwen3-32B-FP8) with just 6GB GPU memory using ZO2. Please refer to our [example](example/mezo_runner/).\n- 01/05/2025: We upgraded the environment and dependencies to align with the latest `transformers==4.51.3`. \n- 06/03/2025: We have open-sourced ZO2!\n\n## 💡 Key Features\n\n- **Optimized ZO CPU Offloading**: ZO2 leverages `zeroth-order (ZO)` methods to efficiently use `CPU offloading`, avoiding redundant data transfers and significantly reducing GPU memory demands. 
This allows for handling large-scale models on hardware with limited GPU resources.\n- **Dynamic Scheduling**: Incorporates a high-performance scheduler to optimize the `computation-communication overlap`, enhancing GPU utilization and preventing training delays.\n- **Capability for Very Large Models**: Enables the fine-tuning of extraordinarily large models, such as those with over `175 billion parameters`, on a single GPU with as little as `18GB` of memory, which was previously impossible with traditional methods.\n- **Empirical Validation**: ZO2 has demonstrated through rigorous testing that it can efficiently fine-tune massive models `without extra time costs or accuracy losses`, confirming its effectiveness for large-scale model training.\n\n## ⚙️ Installation\n\nWe offer two installation options, and you only need to use one of them to install ZO2:\n\n1. To experiment with our examples, tutorials, or tests, follow these steps to set up the ZO2 environment:\n\n```shell\n  git clone https://github.com/liangyuwang/zo2.git\n  cd zo2/\n  conda env create -f env.yml\n  conda activate zo2\n```\n\n2. If you want to use ZO2 as a package in your own code, you can install it directly in your Python environment.\n\n    Before installing the ZO2 package, ensure you have the required dependencies:\n\n    - [PyTorch](https://pytorch.org/get-started/locally/) >= 2.4.0, CUDA >= 12.1\n\n    Once the dependencies are installed, you can install the ZO2 package using pip:\n\n```shell\n  pip install git+https://github.com/liangyuwang/zo2.git\n```\n\n## 🛠️ Usage\n\nWe utilize the [OPT](https://arxiv.org/abs/2205.01068) models and [MeZO-SGD](https://arxiv.org/abs/2305.17333) as examples. For additional information, please refer to the section on [Supported Models, ZO methods, and Tasks](#-supported-models-zo-methods-and-tasks).\n\n### 1. 
Using [MeZO-Runner](example/mezo_runner/) to Evaluate Fine-tuning Tasks\n\nBefore running the following commands, please ensure that you have cloned the entire project. If you [installed](#️installation) ZO2 using option 2, you will need to run \"git clone https://github.com/liangyuwang/zo2.git\" to obtain the complete project, then navigate to the zo2 folder by \"cd zo2\".\n\n```shell\ncd example/mezo_runner/\nexport CUDA_VISIBLE_DEVICES=0\nMODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n```\n\n### 2. Fine-Tuning HF Models with ZOTrainer / ZOSFTTrainer [[Trainer](./tutorial/huggingface.ipynb)]\n\n```python\nfrom zo2 import ZOConfig, zo_hf_init\nfrom zo2.trainer.hf_transformers import ZOTrainer\nfrom transformers import TrainingArguments\n\n# Model and optimizer init\nzo_config = ZOConfig(method=\"mezo-sgd\", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5)\nwith zo_hf_init(zo_config):\n    from transformers import OPTForCausalLM\n    model = OPTForCausalLM.from_pretrained(\"facebook/opt-125m\")\n    model.zo_init(zo_config)\n\ntraining_args = TrainingArguments(\"test-trainer\")\n\ntrainer = ZOTrainer(\n    model,\n    args = training_args,\n    train_dataset=...,   # get training dataset\n    eval_dataset=...,    # get eval dataset\n    data_collator=...,   # get data_collator\n    tokenizer=...,       # use suitable tokenizer\n    ...\n)\n\ntrainer.train()\n```\n\n### 3. 
Train HF Models with Custom Training Loop [[demo](./tutorial/demo.ipynb)]\n\n```python\nfrom zo2 import ZOConfig, zo_hf_init\n\n# Model and optimizer init\nzo_config = ZOConfig(method=\"mezo-sgd\", zo2=True, offloading_device='cpu', working_device='cuda', lr=1e-5)\nwith zo_hf_init(zo_config):\n    from transformers import OPTForCausalLM\n    model = OPTForCausalLM.from_pretrained(\"facebook/opt-125m\")\n    model.zo_init(zo_config)\n\n# Training loop\nfor i in range(max_training_step):\n    # Train\n    training_input_ids, training_labels = ...   # get training data batch\n    model.zo_train()\n    loss = model(input_ids=training_input_ids, labels=training_labels)\n    # Evaluate\n    eval_input_ids, eval_labels = ...   # get eval data batch\n    model.zo_eval()     \n    output = model(input_ids=eval_input_ids, labels=eval_labels)\n```\n\n## ✨ Tutorial\n\nPlease refer to [tutorial](./tutorial/).\n\n## 🤖 Supported Models, ZO methods, and Tasks\n\n- **Models**:\n\n  * [NanoGPT](https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py)   (mainly for idea evaluation)\n  * [Transformers](https://github.com/huggingface/transformers): [OPT](https://arxiv.org/abs/2205.01068), [Qwen3](https://qwenlm.github.io/blog/qwen3/)\n- **ZO methods**:\n\n  * [MeZO-SGD](https://arxiv.org/abs/2305.17333)\n- **Tasks**: Please refer to [MeZO-Runner](example/mezo_runner/)\n\n## 🧪 Test\n\nPlease refer to [test](./test/).\n\n## 🧭 Roadmap\n\n- [ ] Support more models like LLaMA and DeepSeek\n- [ ] Support more ZO methods\n- [ ] Support more offloading strategies (Disk offloading)\n\n## 🚶 Contributing\n\nFeel free to submit issues and pull requests to improve the project!\n\n## 📲 Contact\n\n* Liangyu Wang: liangyu.wang@kaust.edu.sa\n\n## 📖 BibTeX\n\n```\n@article{wang2025zo2,\n  title={ZO2: Scalable Zeroth-Order Fine-Tuning for Extremely Large Language Models with Limited GPU Memory},\n  author={Wang, Liangyu and Ren, Jie and Xu, Hang and Wang, Junxiao and Xie, Huanyi and Keyes, David E and Wang, Di},\n  journal={arXiv preprint 
arXiv:2503.12668},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "env.yml",
    "content": "name: zo2\nchannels:\n  - pytorch\n  - nvidia\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - blas=1.0=mkl\n  - brotli-python=1.0.9=py311h6a678d5_8\n  - bzip2=1.0.8=h5eee18b_6\n  - ca-certificates=2024.7.2=h06a4308_0\n  - certifi=2024.7.4=py311h06a4308_0\n  - charset-normalizer=3.3.2=pyhd3eb1b0_0\n  - cuda-cudart=12.1.105=0\n  - cuda-cupti=12.1.105=0\n  - cuda-libraries=12.1.0=0\n  - cuda-nvrtc=12.1.105=0\n  - cuda-nvtx=12.1.105=0\n  - cuda-opencl=12.6.37=0\n  - cuda-runtime=12.1.0=0\n  - cuda-version=12.6=3\n  - ffmpeg=4.3=hf484d3e_0\n  - filelock=3.13.1=py311h06a4308_0\n  - freetype=2.12.1=h4a9f257_0\n  - gmp=6.2.1=h295c915_3\n  - gmpy2=2.1.2=py311hc9b5ff0_0\n  - gnutls=3.6.15=he1e5248_0\n  - idna=3.7=py311h06a4308_0\n  - intel-openmp=2023.1.0=hdb19cb5_46306\n  - jinja2=3.1.4=py311h06a4308_0\n  - jpeg=9e=h5eee18b_3\n  - lame=3.100=h7b6447c_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.38=h1181459_1\n  - lerc=3.0=h295c915_0\n  - libcublas=12.1.0.26=0\n  - libcufft=11.0.2.4=0\n  - libcufile=1.11.0.15=0\n  - libcurand=10.3.7.37=0\n  - libcusolver=11.4.4.55=0\n  - libcusparse=12.0.2.55=0\n  - libdeflate=1.17=h5eee18b_1\n  - libffi=3.4.4=h6a678d5_1\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libiconv=1.16=h5eee18b_3\n  - libidn2=2.3.4=h5eee18b_0\n  - libjpeg-turbo=2.0.0=h9bf148f_0\n  - libnpp=12.0.2.50=0\n  - libnvjitlink=12.1.105=0\n  - libnvjpeg=12.1.1.14=0\n  - libpng=1.6.39=h5eee18b_0\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - libtasn1=4.19.0=h5eee18b_0\n  - libtiff=4.5.1=h6a678d5_0\n  - libunistring=0.9.10=h27cfd23_0\n  - libuuid=1.41.5=h5eee18b_0\n  - libwebp-base=1.3.2=h5eee18b_0\n  - llvm-openmp=14.0.6=h9e868ea_0\n  - lz4-c=1.9.4=h6a678d5_1\n  - markupsafe=2.1.3=py311h5eee18b_0\n  - mkl=2023.1.0=h213fc3f_46344\n  - mkl-service=2.4.0=py311h5eee18b_1\n  - mkl_fft=1.3.8=py311h5eee18b_0\n  - mkl_random=1.2.4=py311hdb19cb5_0\n  - mpc=1.1.0=h10f8cd9_1\n  - 
mpfr=4.0.2=hb69a4c5_1\n  - mpmath=1.3.0=py311h06a4308_0\n  - ncurses=6.4=h6a678d5_0\n  - nettle=3.7.3=hbbd107a_1\n  - networkx=3.3=py311h06a4308_0\n  - numpy=1.26.4=py311h08b1b3b_0\n  - numpy-base=1.26.4=py311hf175353_0\n  - openh264=2.1.1=h4ff587b_0\n  - openjpeg=2.5.2=he7f1fd0_0\n  - openssl=3.0.14=h5eee18b_0\n  - pillow=10.4.0=py311h5eee18b_0\n  - pip=24.0=py311h06a4308_0\n  - pysocks=1.7.1=py311h06a4308_0\n  - python=3.11.9=h955ad1f_0\n  - pytorch=2.4.0=py3.11_cuda12.1_cudnn9.1.0_0\n  - pytorch-cuda=12.1=ha16c6d3_5\n  - pytorch-mutex=1.0=cuda\n  - pyyaml=6.0.1=py311h5eee18b_0\n  - readline=8.2=h5eee18b_0\n  - requests=2.32.3=py311h06a4308_0\n  - setuptools=72.1.0=py311h06a4308_0\n  - sqlite=3.45.3=h5eee18b_0\n  - sympy=1.12=py311h06a4308_0\n  - tbb=2021.8.0=hdb19cb5_0\n  - tk=8.6.14=h39e8969_0\n  - torchaudio=2.4.0=py311_cu121\n  - torchtriton=3.0.0=py311\n  - torchvision=0.19.0=py311_cu121\n  - typing_extensions=4.11.0=py311h06a4308_0\n  - urllib3=2.2.2=py311h06a4308_0\n  - wheel=0.43.0=py311h06a4308_0\n  - xz=5.4.6=h5eee18b_1\n  - yaml=0.2.5=h7b6447c_0\n  - zlib=1.2.13=h5eee18b_1\n  - zstd=1.5.5=hc292b87_2\n  - pip:\n      - accelerate==1.6.0\n      - aiohappyeyeballs==2.3.5\n      - aiohttp==3.10.3\n      - aiosignal==1.3.1\n      - attrs==24.2.0\n      - datasets==3.5.1\n      - dill==0.3.8\n      - frozenlist==1.4.1\n      - fsspec==2024.5.0\n      - huggingface-hub==0.30.2\n      - joblib==1.4.2\n      - markdown-it-py==3.0.0\n      - mdurl==0.1.2\n      - multidict==6.0.5\n      - multiprocess==0.70.16\n      - nvidia-ml-py==12.570.86\n      - opt-einsum==3.3.0\n      - packaging==24.1\n      - pandas==2.2.2\n      - psutil==6.0.0\n      - pyarrow==17.0.0\n      - pyarrow-hotfix==0.6\n      - pygments==2.19.1\n      - python-dateutil==2.9.0.post0\n      - pytz==2024.1\n      - regex==2024.7.24\n      - rich==14.0.0\n      - safetensors==0.5.3\n      - scikit-learn==1.5.1\n      - scipy==1.14.0\n      - six==1.16.0\n      - threadpoolctl==3.5.0\n      - 
tokenizers==0.21.1\n      - tqdm==4.66.5\n      - transformers==4.51.3\n      - trl==0.17.0\n      - tzdata==2024.1\n      - xxhash==3.4.1\n      - yarl==1.9.4\n"
  },
  {
    "path": "example/demo/train_zo2_with_hf_opt.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport argparse\nfrom tqdm.auto import tqdm\nimport torch\nfrom transformers import AutoTokenizer\nfrom zo2 import (\n    ZOConfig,\n    zo_hf_init,\n)\nfrom zo2.utils.utils import seed_everything\n\n# Hyper\nargs = argparse.ArgumentParser()\nargs.add_argument(\"--zo_method\", type=str, default=\"zo2\")\nargs.add_argument(\"--eval\", action=\"store_true\")\nargs.add_argument(\"--model_name\", type=str, default=\"facebook/opt-2.7b\")\nargs.add_argument(\"--verbose\", action=\"store_true\")\nargs.add_argument(\"--max_steps\", type=int, default=100)\nargs.add_argument(\"--lr\", type=float, default=1e-5)\nargs.add_argument(\"--weight_decay\", type=float, default=1e-1)\nargs.add_argument(\"--zo_eps\", type=float, default=1e-3)\nargs.add_argument(\"--seed\", type=int, default=42)\nargs.add_argument(\"--offloading_device\", type=str, default=\"cpu\")\nargs.add_argument(\"--working_device\", type=str, default=\"cuda:0\")\n# For inference\nargs.add_argument(\"--use_cache\", action=\"store_true\")\nargs.add_argument(\"--max_new_tokens\", type=int, default=50)\nargs.add_argument(\"--temperature\", type=float, default=1.0)\nargs = args.parse_args()\n\nseed_everything(args.seed)\n\n# ZO steps\nzo_config = ZOConfig(\n    method=\"mezo-sgd\", \n    zo2=args.zo_method==\"zo2\", \n    lr=args.lr,\n    weight_decay=args.weight_decay,\n    eps=args.zo_eps,\n    offloading_device=args.offloading_device,\n    working_device=args.working_device,\n)\n\n# Load ZO model\nwith zo_hf_init(zo_config):\n    from transformers import OPTForCausalLM\n    model = OPTForCausalLM.from_pretrained(args.model_name)\n    model.zo_init(zo_config)\nif args.zo_method != \"zo2\": model = model.to(args.working_device)\nprint(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")\n\n# Prepare some data\ndataset = \"\"\"\n    What is ZO2? 
\n    ZO2 is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using zeroth-order (ZO) optimization techniques and advanced offloading technologies. \n    This framework is particularly tailored for setups with limited GPU memory, enabling the fine-tuning of models that were previously unmanageable due to hardware constraints. \n    As the scale of Large Language Models (LLMs) continues to grow, reaching parameter counts in the hundreds of billions, managing GPU memory resources effectively becomes crucial. \n    Efficient GPU memory management is crucial not only because it directly influences model performance and training speed, but also because GPU memory is both expensive and limited in quantity. \n    However, this creates a significant challenge in handling ever-larger models within the physical constraints of current hardware technologies. \n    CPU offloading has become a crucial technique for overcoming this challenge. \n    It involves transferring computations and data from the GPU to the CPU, specifically targeting data or parameters that are less frequently accessed. \n    By offloading these inactive tensors of the neural network, CPU offloading effectively alleviates the memory and computational pressures on GPUs. \n    While CPU offloading has been commonly applied in inference to manage memory-intensive tasks, its application in training, especially fine-tuning, remains less explored. \n    Recently, some works have tried to introduce CPU offloading into LLM training. \n    However, they are typically constrained by the capabilities of first-order optimizers such as SGD and Adaptive Moment Estimation (AdamW), and limited GPU memory, restricting large-scale model scalability on single GPU systems. 
\n    Using first-order optimizers introduces inefficiencies in CPU offloading: Multiple communication operations during the training of LLMs necessitate offloading the same data twice—once for each pass. \n    This redundancy not only doubles the communication volume between the CPU and GPU but also introduces significant latency due to repetitive data transfers. \n    Furthermore, both parameters and activations are required in the backward pass to complete gradient computations. \n    This means that parameters and activation values must be offloaded during each forward pass and re-uploaded to the GPU for the backward pass, increasing the volume of data transferred, which severely impacts training throughput. \n    On the other hand, zeroth-order (ZO) methods offer a novel approach to fine-tuning LLMs. \n    These methods utilize dual forward passes to estimate parameter gradients and subsequently update parameters. \n    This approach eliminates the traditional reliance on backward passes, thereby streamlining the training process by significantly reducing the number of computational steps required. \n    Based on these observations, we conjecture that ZO's architecture is particularly well-suited for CPU offloading strategies. \n    By eliminating backward passes and the need to store activation values, it can significantly reduce GPU memory demands through efficient parameter offloading. \n    However, despite these advantages, ZO training via CPU offloading introduces new challenges, particularly in the realm of CPU-to-GPU communication. \n    Transferring parameters between the CPU and GPU, which is crucial for maintaining gradient computation and model updates, becomes a critical bottleneck. \n    Although ZO methods inherently extend computation times because of the dual forward passes, potentially allowing for better overlap between computation and communication, there remain significant inefficiencies. 
\n    The necessity to upload parameters to the GPU for upcoming computations introduces a large volume of communications. To tackle the inefficiencies highlighted, we introduce ZO2, a novel framework specifically designed for ZO fine-tuning in LLMs with CPU offloading. \n    This framework utilizes the unique dual forward pass architecture of ZO methods to optimize interactions between CPU and GPU, significantly enhancing both computational and communication efficiency. \n    By building a high-performance dynamic scheduler, ZO2 achieves substantial overlaps in communication and computation. \n    These innovations make it feasible to fine-tune extremely large models, such as the OPT-175B, with over 175 billion parameters, on a single GPU equipped with just 18GB of memory usage—a capability previously unattainable with conventional methods. \n    Additionally, our efficient framework operates without any extra time cost and decreases in accuracy compared to standard ZO methodologies.\"\"\"\ntokenizer = AutoTokenizer.from_pretrained(args.model_name)\ndata_batch = tokenizer(dataset, add_special_tokens=True, return_tensors='pt').input_ids.to(args.working_device)\nT = min(data_batch.shape[1] - 1, model.config.max_position_embeddings)\nprint(f\"Fine-tuning model {args.model_name} with {T} tokens dataset: \\n{dataset}\")\n\n# Training loop\nfor i in tqdm(range(args.max_steps)):\n    # train\n    model.zo_train()\n    loss = model(input_ids=data_batch, labels=data_batch)\n\n    # eval\n    if args.eval:\n        if i==0:\n            tqdm.write(\"Warning: please notice that ZO2 does not optimize the evaluation, so it may be very slow.\")\n        model.zo_eval()\n        output = model(input_ids=data_batch, labels=data_batch)\n        res = \"Iteration {}, train loss: {}, projected grad: {}, eval loss: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad, output[\"loss\"]))\n    else:\n        res = \"Iteration {}, train loss: {}, projected grad: {}\"\n  
      tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\n# inference\nprint(\"Doing inference...\")\nprint(\"Warning: please notice that ZO2 does not optimize the inference, so it may be very slow.\")\nmodel.zo_eval()\nprompt = \"What is ZO2 and how ZO2 enhance the fine-tuning of large language models?\"\ninputs = tokenizer(prompt, return_tensors='pt').to(args.working_device)\ninputs = {\"input_ids\": inputs.input_ids}\nfor _ in tqdm(range(args.max_new_tokens)):\n    outputs = model(**inputs, return_dict=True)\n    next_token_logits = outputs.logits[:, -1, :]\n    if args.temperature == 1.0:\n        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)\n    else:\n        scaled_logits = next_token_logits / args.temperature\n        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)\n        next_token = torch.multinomial(probs, num_samples=1)\n    inputs = torch.cat([inputs[\"input_ids\"], next_token], dim=-1)\n    generated_text = tokenizer.decode(inputs[0])\n    inputs = {\"input_ids\": inputs}\nprint(f\"Question: {prompt}\")\nprint(f\"Response: {generated_text[len(prompt)+4:]}...\")"
  },
  {
    "path": "example/mezo_runner/README.md",
    "content": "# Example: Apply MeZO on LLMs\n\nModified from [MeZO](https://github.com/princeton-nlp/MeZO/blob/main/large_models/README.md)\n\n## Usage\n\nUse `run.py` for all modes (zero-shot / MeZO):\n\n```bash\npython run.py {ARGUMENTS}\n```\n\nPlease read [run.py](./run.py) for a complete list of arguments. We introduce some of the most important ones below.\n\n* `--num_train`: Number of training examples.\n* `--num_dev`: Number of development (validation) examples, held out from the sampled training set.\n* `--num_eval`: Number of evaluation examples, sampled from the task's validation split.\n* `--model_name`: HuggingFace model name or path.\n* `--task_name`: Task name.\n* `--trainer`: can be `none` (zero-shot) or `zo` (MeZO).\n* `--train_as_classification`: turn this on for classification tasks (cross-entropy over the likelihood of each class's label words). Otherwise, training uses LM-style teacher forcing.\n* `--zo_eps`: MeZO hyperparameter epsilon.\n* `--zo_method`: zeroth-order optimization method (default: `mezo-sgd`).\n* `--zo_mode`: can be `zo` (on device) or `zo2` (offloading).\n* `--offloading_device`: offloading device.\n* `--working_device`: main working device.\n\nExample:\n\n1. MeZO (full-parameter fine-tuning)\n```bash\n# You can adjust the following model size and other hyperparameters.\n# OPT-2.7B\nMODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n# Qwen3-1.7B\nMODEL=Qwen/Qwen3-1.7B TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n```\n\n## Supported Tasks (See [tasks.py](./tasks.py))\n\n- **SST2**\n- **Copa**\n- **BoolQ**\n- **MultiRC**\n- **CB**\n- **WIC**\n- **WSC**\n- **ReCoRD**\n- **RTE**\n- **SQuAD**\n- **DROP**\n"
  },
  {
    "path": "example/mezo_runner/metrics.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport numpy as np\nimport collections\nimport re\nimport string\nfrom collections import Counter\n\ndef normalize_answer(s):\n    \"\"\"Lower text and remove punctuation, articles and extra whitespace.\"\"\"\n\n    def remove_articles(text):\n        return re.sub(r'\\b(a|an|the)\\b', ' ', text)\n\n    def white_space_fix(text):\n        return ' '.join(text.split())\n\n    def remove_punc(text):\n        exclude = set(string.punctuation)\n        return ''.join(ch for ch in text if ch not in exclude)\n\n    def lower(text):\n        return text.lower()\n\n    return white_space_fix(remove_articles(remove_punc(lower(s))))\n\n\ndef calculate_metric(predictions, metric_name):\n    if metric_name == \"accuracy\":\n        if isinstance(predictions[0].correct_candidate, list):\n            return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions])\n        else:\n            return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions])\n    elif metric_name == \"em\":\n        # For question answering\n        return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions])\n    elif metric_name == \"f1\":\n        # For question answering\n        f1 = []\n        for pred in predictions:\n            all_f1s = []\n            if pred.correct_candidate[0] == \"CANNOTANSWER\" or pred.correct_candidate[0] == \"no answer\":\n                f1.append(int(normalize_answer(pred.correct_candidate[0]) == normalize_answer(pred.predicted_candidate)))\n            else:\n                for ans in pred.correct_candidate:\n                    prediction_tokens = normalize_answer(pred.predicted_candidate).split()\n                    ground_truth_tokens = normalize_answer(ans).split()\n                    common = 
Counter(prediction_tokens) & Counter(ground_truth_tokens)\n                    num_same = sum(common.values())\n                    if num_same == 0:\n                        all_f1s.append(0)\n                    else:\n                        precision = 1.0 * num_same / len(prediction_tokens)\n                        recall = 1.0 * num_same / len(ground_truth_tokens)\n                        all_f1s.append((2 * precision * recall) / (precision + recall))\n                f1.append(max(all_f1s))\n\n        return np.mean(f1)\n\n\ndef f1(pred, gold):\n    \"\"\"\n    This separate F1 function is used as non-differentiable metric for SQuAD\n    \"\"\"\n    if gold[0] == \"CANNOTANSWER\" or gold[0] == \"no answer\":\n        return int(normalize_answer(gold[0]) == normalize_answer(pred))\n    else:\n        all_f1s = []\n        for ans in gold:\n            prediction_tokens = normalize_answer(pred).split()\n            ground_truth_tokens = normalize_answer(ans).split()\n            common = Counter(prediction_tokens) & Counter(ground_truth_tokens)\n            num_same = sum(common.values())\n            if num_same == 0:\n                all_f1s.append(0)\n            else:\n                precision = 1.0 * num_same / len(prediction_tokens)\n                recall = 1.0 * num_same / len(ground_truth_tokens)\n                all_f1s.append((2 * precision * recall) / (precision + recall))\n        return np.max(all_f1s)"
  },
  {
    "path": "example/mezo_runner/mezo.sh",
    "content": "MODEL=${MODEL:-facebook/opt-1.3b}\nMODEL_NAME=(${MODEL//\\// })\nMODEL_NAME=\"${MODEL_NAME[-1]}\"\n\nBS=${BS:-16}\nLR=${LR:-1e-5}\nEPS=${EPS:-1e-3}\nSEED=${SEED:-0}\nTRAIN=${TRAIN:-1000}\nDEV=${DEV:-500}\nEVAL=${EVAL:-1000}\nSTEPS=${STEPS:-20000}\nEVAL_STEPS=${EVAL_STEPS:-4000}\n\nMODE=${MODE:-ft}\nEXTRA_ARGS=\"\"\nif [ \"$MODE\" == \"prefix\" ]; then\n    EXTRA_ARGS=\"--prefix_tuning --num_prefix 5 --no_reparam --prefix_init_by_real_act\"\nelif [ \"$MODE\" == \"lora\" ]; then\n    EXTRA_ARGS=\"--lora\"\nfi\nTAG=mezo-$MODE-$STEPS-$BS-$LR-$EPS-$SEED\n\nTASK_ARGS=\"\"\ncase $TASK in\n    # For Copa, ReCoRD, SQuAD, DROP, we set --train_as_classification False; for others, set this flag to True\n    CB) # It has <1000 training examples. Only use 100 for dev\n        DEV=100\n        ;;\n    Copa) # It has <1000 training examples. Only use 100 for dev\n        DEV=100\n        TASK_ARGS=\"--train_as_classification False\"\n        ;;\n    ReCoRD) \n        TASK_ARGS=\"--train_as_classification False\"\n        ;;\n    DROP) \n        TASK_ARGS=\"--train_as_classification False\"\n        ;;\n    SQuAD)\n        TASK_ARGS=\"--train_as_classification False\"\n        ;;\nesac\n\necho $TAG\necho \"BS: $BS\"\necho \"LR: $LR\"\necho \"EPS: $EPS\"\necho \"SEED: $SEED\"\necho \"TRAIN/EVAL STEPS: $STEPS/$EVAL_STEPS\"\necho \"MODE: $MODE\"\necho \"Extra args: $EXTRA_ARGS $TASK_ARGS\"\n\npython run.py \\\n    --model_name $MODEL \\\n    --task_name $TASK \\\n    --output_dir result/$TASK-${MODEL_NAME}-$TAG --tag $TAG --train_set_seed $SEED --num_train $TRAIN --num_dev $DEV --num_eval $EVAL --logging_steps 10 \\\n    --max_steps $STEPS \\\n    --trainer zo --load_float16 \\\n    --learning_rate $LR --zo_eps $EPS --per_device_train_batch_size $BS --lr_scheduler_type \"constant\" \\\n    --load_best_model_at_end --eval_strategy steps --save_strategy steps --save_total_limit 1 \\\n    --eval_steps $EVAL_STEPS --save_steps $EVAL_STEPS \\\n    --train_as_classification 
\\\n    $EXTRA_ARGS \\\n    $TASK_ARGS \\\n    \"$@\"\n"
  },
  {
    "path": "example/mezo_runner/run.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/princeton-nlp/MeZO/blob/main/large_models/run.py\n\"\"\"\n\nimport sys\nsys.path.append(\"../../../zo2\")\n\nimport logging\n\nlogging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\nimport argparse\nimport time\nimport tasks\nfrom transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, Trainer, HfArgumentParser, Trainer, TrainingArguments, DataCollatorWithPadding, DataCollatorForTokenClassification\nfrom typing import Union, Optional\nimport torch\nfrom torch.nn.parameter import Parameter\nimport numpy as np\nfrom dataclasses import dataclass, is_dataclass, asdict\nfrom tqdm import tqdm\nfrom tasks import get_task\nimport json\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\nfrom metrics import calculate_metric\nfrom utils import *\nimport random\n\nfrom zo2.trainer.hf_transformers.trainer import ZOTrainer\nfrom zo2 import zo_hf_init, ZOConfig\n\n@dataclass\nclass OurArguments(TrainingArguments):\n    # dataset and sampling strategy\n    task_name: str = \"SST2\" # task name should match the string before Dataset in the Dataset class name. 
We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP\n\n    # Number of examples\n    num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples\n    num_dev: int = None # (only enabled with training) number of development samples\n    num_eval: int = None # number of evaluation samples\n    num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample\n    train_set_seed: int = None # designated seed to sample training samples/demos\n    result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config\n\n    # Model loading\n    model_name: str = \"facebook/opt-125m\" # HuggingFace model name\n    load_float16: bool = False # load model parameters as float16\n    load_bfloat16: bool = False # load model parameters as bfloat16\n    load_int8: bool = False # load model parameters as int8\n    max_length: int = 2048 # max length the model can take\n    no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP\n\n    # Calibration\n    sfc: bool = False # whether to use SFC calibration\n    icl_sfc: bool = False # whether to use SFC calibration for ICL samples\n\n    # Training\n    trainer: str = \"none\" \n    ## options\n    ## - none: no training -- for zero-shot or in-context learning (ICL)\n    ## - regular: regular huggingface trainer -- for fine-tuning\n    ## - zo: zeroth-order (MeZO) training\n    only_train_option: bool = True # whether to only train the option part of the input\n    train_as_classification: bool = False # take the log likelihood of all options and train as classification \n\n    # MeZO\n    zo_eps: float = 1e-3 # eps in MeZO\n\n    # Prefix tuning\n    prefix_tuning: bool = False # whether to use prefix tuning\n    num_prefix: int = 5 # number of 
prefixes to use\n    no_reparam: bool = True # do not use reparameterization trick\n    prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words\n\n    # LoRA\n    lora: bool = False # whether to use LoRA\n    lora_alpha: int = 16 # alpha in LoRA\n    lora_r: int = 8 # r in LoRA\n\n    # Generation\n    sampling: bool = False # whether to use sampling\n    temperature: float = 1.0 # temperature for generation\n    num_beams: int = 1 # number of beams for generation\n    top_k: int = None # top-k for generation\n    top_p: float = 0.95 # top-p for generation\n    max_new_tokens: int = 50 # max number of new tokens to generate\n    eos_token: str = \"\\n\" # end of sentence token\n\n    # Saving\n    save_model: bool = False # whether to save the model\n    no_eval: bool = False # whether to skip evaluation\n    tag: str = \"\" # saving tag\n\n    # Linear probing\n    linear_probing: bool = False # whether to do linear probing\n    lp_early_stopping: bool = False # whether to do early stopping in linear probing\n    head_tuning: bool = False # head tuning: only tune the LM head\n\n    # Untie emb/lm_head weights\n    untie_emb: bool = False # untie the embeddings and LM head\n\n    # Display\n    verbose: bool = False # verbose output\n\n    # Non-diff objective\n    non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now)\n\n    # Auto saving when interrupted\n    save_on_interrupt: bool = False # save model when interrupted (useful for long training)\n\n    # ZO2 added -> ZO2 configs\n    zo_method: str = \"mezo-sgd\"\n    zo_mode: str = \"zo2\"\n    offloading_device: str = \"cpu\"\n    working_device: str = \"cuda:0\"\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser = HfArgumentParser(OurArguments)\n    args = parser.parse_args_into_dataclasses()[0]\n    print(args)\n    return args\n\n\ndef set_seed(seed: int):\n    random.seed(seed)\n    np.random.seed(seed)\n    
torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\nclass Framework:\n\n    def __init__(self, args, task):\n        self.args = args\n        self.task = task\n        self.model, self.tokenizer = self.load_model()\n\n\n    def load_model(self):\n        \"\"\"\n        Load HuggingFace models\n        \"\"\"\n        with count_time(\"Loading model with FP%d\" % (16 if self.args.load_float16 else 32)):\n            free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)\n            config = AutoConfig.from_pretrained(self.args.model_name)\n            if self.args.untie_emb:\n                # Untie embeddings/LM head\n                logger.warn(\"Untie embeddings and LM head\")\n                config.tie_word_embeddings = False\n            # if self.args.head_tuning:\n            #     # Head tuning\n            #     from ht_opt import OPTForCausalLM\n            #     model = OPTForCausalLM.from_pretrained(\n            #         self.args.model_name,\n            #         config=config,\n            #     )\n            # elif self.args.no_auto_device:\n            #     # No auto device (use for FSDP)\n            #     model = AutoModelForCausalLM.from_pretrained(\n            #         self.args.model_name,\n            #         config=config,\n            #     )\n            # else:\n            #     # Auto device loading\n            #     torch_dtype = torch.float32\n            #     if self.args.load_float16:\n            #         torch_dtype = torch.float16\n            #     elif self.args.load_bfloat16:\n            #         torch_dtype = torch.bfloat16\n            #     model = AutoModelForCausalLM.from_pretrained(\n            #         self.args.model_name,\n            #         config=config,\n            #         device_map='auto',\n            #         torch_dtype=torch_dtype,\n            #         max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},\n            #         
load_in_8bit=self.args.load_int8,\n            #     )\n\n            # ZO2 added -> init ZO2 model\n            torch_dtype = torch.float32\n            if self.args.load_float16:\n                torch_dtype = torch.float16\n            elif self.args.load_bfloat16:\n                torch_dtype = torch.bfloat16\n            # Set up ZO configuration\n            self.zo_config = ZOConfig(\n                method=\"mezo-sgd\",\n                zo2=(self.args.zo_mode == \"zo2\"),\n                lr=self.args.learning_rate,\n                weight_decay=self.args.weight_decay,\n                eps=self.args.zo_eps,\n                offloading_device=self.args.offloading_device,\n                working_device=self.args.working_device,\n            )\n            # Initialize model within zo_hf_init context\n            with zo_hf_init(self.zo_config):\n                if \"opt\" in self.args.model_name:\n                    from transformers import OPTForCausalLM\n                    model = OPTForCausalLM.from_pretrained(\n                        self.args.model_name, \n                        config=config,\n                        torch_dtype=torch_dtype,\n                        max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},\n                        load_in_8bit=self.args.load_int8,\n                    )\n                elif \"Qwen3\" in self.args.model_name:\n                    from transformers import Qwen3ForCausalLM\n                    model = Qwen3ForCausalLM.from_pretrained(\n                        self.args.model_name,\n                        config=config,\n                        torch_dtype=torch_dtype,\n                        max_memory={i: f'{free_in_GB-5}GB' for i in range(torch.cuda.device_count())},\n                        load_in_8bit=self.args.load_int8,\n                    )\n                model.zo_init(self.zo_config)\n            logger.info(f\"Check if zo2 init correctly: {hasattr(model, 
'zo_training')}\")\n            # If using a method other than zo2, move model to working device\n            if self.args.zo_method != \"zo2\":\n                model = model.to(self.args.working_device)\n\n            model.eval()\n\n        # Load tokenizer\n        tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, use_fast=False)\n\n        # HF tokenizer bug fix\n        if \"opt\" in self.args.model_name:\n            tokenizer.bos_token_id = 0\n        \n        if \"llama\" in self.args.model_name:\n            # LLaMA padding token\n            tokenizer.pad_token_id = 0 # technically <unk>\n\n        if \"Qwen3\" in self.args.model_name:\n            # LLaMA padding token\n            tokenizer.add_bos_token = False\n\n        # Prefix tuning/LoRA\n        if self.args.prefix_tuning:\n            # from prefix import PrefixTuning\n            # PrefixTuning(model, num_prefix=self.args.num_prefix, reparam=not self.args.no_reparam, float16=self.args.load_float16, init_by_real_act=self.args.prefix_init_by_real_act)\n            raise NotImplementedError\n        if self.args.lora:\n            # from lora import LoRA\n            # LoRA(model, r=self.args.lora_r, alpha=self.args.lora_alpha, float16=self.args.load_float16)\n            raise NotImplementedError\n\n        if self.args.head_tuning:\n            # if model.config.model_type == \"opt\":\n            #     head_name = \"lm_head\" if self.args.untie_emb else \"embed_tokens\"\n            # else:\n            #     raise NotImplementedError\n            # for n, p in model.named_parameters():\n            #     if head_name not in n:\n            #         p.requires_grad = False\n            #     else:\n            #         logger.info(f\"Only tuning {n}\")\n            raise NotImplementedError\n\n        return model, tokenizer\n\n\n    def forward(self, input_ids, option_len=None, generation=False):\n        \"\"\"\n        Given input_ids and the length of the option, return 
the log-likelihood of each token in the option.\n        For generation tasks, return the generated text.\n        This function is only for inference\n        \"\"\"\n        input_ids = torch.tensor([input_ids]).to(self.model.device)\n\n        if generation:\n            args = self.args\n            # Autoregressive generation\n            outputs = self.model.generate(\n                input_ids, do_sample=args.sampling, temperature=args.temperature, \n                num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k, max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)), \n                num_return_sequences=1, eos_token_id=[self.tokenizer.encode(args.eos_token, add_special_tokens=False)[-1], self.tokenizer.eos_token_id],\n            )\n            # For generation, directly return the text output\n            output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip()\n            return output_text\n        else:\n            with torch.inference_mode():\n                self.model.eval()\n                logits = self.model(input_ids=input_ids).logits\n            labels = input_ids[0, 1:]\n            logits = logits[0, :-1] \n            log_probs = F.log_softmax(logits, dim=-1)\n\n            selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels]\n            selected_log_probs = selected_log_probs.cpu().detach()\n            # Only return the option (candidate) part\n            return selected_log_probs[-option_len:]\n\n\n    def one_step_pred(self, train_samples, eval_sample, verbose=False):\n        \"\"\"\n        Return the prediction on the eval sample. 
In ICL, use train_samples as demonstrations\n        \"\"\"\n        verbose = verbose or self.args.verbose\n        if verbose:\n            logger.info(\"========= Example =========\")\n            logger.info(f\"Candidate: {eval_sample.candidates}\")\n            logger.info(f\"Correct candidate: {eval_sample.correct_candidate}\")\n\n\n        # Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options)\n        encoded_candidates, option_lens = encode_prompt(\n            self.task, self.task.get_template(), train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length, \n            generation=self.task.generation, max_new_tokens=self.args.max_new_tokens\n        )\n\n        # Calibration\n        if self.args.sfc or self.args.icl_sfc:\n            sfc_encoded_candidates, sfc_option_lens = encode_prompt(self.task, self.task.get_template(), \n                train_samples, eval_sample, self.tokenizer, max_length=self.args.max_length,\n                sfc=self.args.sfc, icl_sfc=self.args.icl_sfc, generation=self.task.generation, \n                max_new_tokens=self.args.max_new_tokens\n            )\n\n        outputs = []\n        if self.task.generation:\n            # For generation tasks, return the autoregressively-generated text\n            output_text = self.forward(encoded_candidates[0], generation=True)\n            if verbose:\n                logger.info(\"=== Prompt ===\")\n                logger.info(self.tokenizer.decode(encoded_candidates[0]))\n                logger.info(f\"Output: {output_text}\") \n            return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text)\n        else:\n            # For classification/multiple-choice, calculate the probabilities of all candidates\n            for candidate_id, encoded_candidate in enumerate(encoded_candidates):\n                selected_log_probs = self.forward(encoded_candidate, 
option_len=option_lens[candidate_id])\n                if verbose:\n                    if candidate_id == 0:\n                        logger.info(\"=== Candidate %d ===\" % candidate_id)\n                        logger.info(self.tokenizer.decode(encoded_candidate))\n                    else:\n                        logger.info(\"=== Candidate %d (without context)===\" % candidate_id)\n                        logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1])\n                    logger.info(f\"Log probabilities of the option tokens: {selected_log_probs}\")\n\n                if self.args.sfc or self.args.icl_sfc:\n                    sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id], option_len=sfc_option_lens[candidate_id])\n                    if verbose:\n                        logger.info(\"=== Candidate %d (without context) SFC ===\" % candidate_id)\n                        logger.info(self.tokenizer.decode(sfc_encoded_candidates[candidate_id]).split(self.task.train_sep)[-1])\n                        logger.info(f\"Log probabilities of the option tokens: {sfc_selected_log_probs}\")\n\n                outputs.append({\"log_probs\": selected_log_probs, \"sfc_log_probs\": sfc_selected_log_probs if self.args.sfc or self.args.icl_sfc else None})\n\n            if self.args.sfc or self.args.icl_sfc:\n                # Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf)\n                # log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt)\n                scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs]\n            else:\n                # (Default) length-normalized log probabilities\n                # log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens|\n                scores = [x['log_probs'].mean().item() for x in outputs]\n\n            if verbose:\n           
     logger.info(f\"Prediction scores: {scores}\")\n\n            if isinstance(eval_sample.correct_candidate, list):\n                # For some datasets there are multiple correct answers\n                correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate]\n            else:\n                correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate)\n\n            return Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores)))\n\n\n    def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False):\n        \"\"\"\n        Evaluate function. If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set.\n        \"\"\"\n        if one_train_set_per_eval_sample:\n            logger.info(f\"There are {len(eval_samples)} validation samples and one train set per eval sample\")\n        else:\n            logger.info(f\"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples\")\n\n        # Prediction loop\n        predictions = []  \n        for eval_id, eval_sample in enumerate(tqdm(eval_samples)):\n            predictions.append(\n                self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples, eval_sample, verbose=(eval_id < 3))\n            )\n\n        # Calculate metrics \n        metric_name = getattr(self.task, \"metric_name\", \"accuracy\")\n        metrics = {metric_name: calculate_metric(predictions, metric_name)}\n        return metrics\n\n\n    def train(self, train_samples, eval_samples):\n        \"\"\"\n        Training function\n        \"\"\"\n        # Set tokenizer to left padding (so that all the options are right aligned)\n        self.tokenizer.padding_side = \"left\"\n\n        class HFDataset(Dataset):\n\n            def __init__(self, data):\n                self.data = data\n\n            def 
__len__(self):\n                return len(self.data)\n\n            def __getitem__(self, idx):\n                return self.data[idx]\n\n\n        def _convert(samples):\n            \"\"\"\n            Convert samples to HF-compatible dataset\n            \"\"\"\n            data = []\n            for sample in samples:\n                encoded_candidates, option_lens = encode_prompt(\n                    self.task, self.task.get_template(), [], sample, self.tokenizer, \n                    max_length=self.args.max_length, generation=self.task.generation, generation_with_gold=True, \n                    max_new_tokens=self.args.max_new_tokens\n                )\n                if self.task.generation:\n                    correct_candidate_id = 0\n                elif isinstance(sample.correct_candidate, list):\n                    correct_candidate_id = sample.candidates.index(sample.correct_candidate[0])\n                else:\n                    correct_candidate_id = sample.candidates.index(sample.correct_candidate)\n                \n                if self.args.non_diff:\n                    # For non-differentiable objective, there is no teacher forcing thus the \n                    # current answer part is removed\n                    encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][:-option_lens[correct_candidate_id]]\n\n                if self.args.train_as_classification:\n                    # For classification, we provide the label as the correct candidate id\n                    data.append([{\"input_ids\": encoded_candidates[_i], \"labels\": correct_candidate_id, \"option_len\": option_lens[_i], \"num_options\": len(sample.candidates)} for _i in range(len(encoded_candidates))])\n                elif self.args.only_train_option:\n                    # Otherwise, it is just LM-style teacher forcing\n                    if self.args.non_diff:\n                        # For non-differentiable objective, we need 
to provide the gold answer to calculate F1/acc\n                        data.append({\"input_ids\": encoded_candidates[correct_candidate_id], \"labels\": encoded_candidates[correct_candidate_id], \"option_len\": option_lens[correct_candidate_id], \"gold\": sample.correct_candidate})\n                    else:\n                        data.append({\"input_ids\": encoded_candidates[correct_candidate_id], \"labels\": encoded_candidates[correct_candidate_id], \"option_len\": option_lens[correct_candidate_id]})\n                else:\n                    data.append({\"input_ids\": encoded_candidates[correct_candidate_id], \"labels\": encoded_candidates[correct_candidate_id]})\n            return data\n\n        with count_time(\"Tokenizing training samples\"):\n            train_dataset = HFDataset(_convert(train_samples))\n            eval_dataset = HFDataset(_convert(eval_samples))\n        \n        if self.args.only_train_option and not self.args.non_diff:\n        #     # If --only_train_option and not with a non-differentiable objective, we wrap the forward function\n        #     self.model.original_forward = self.model.forward\n        #     self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model))\n            # ZO2 added -> register custom loss functions\n            self.model.zo_custom_train_loss_fn = custom_loss_fn_with_option_len\n            self.model.zo_custom_eval_loss_fn = custom_loss_fn_with_option_len\n\n        if self.args.non_diff:\n            collator = NondiffCollator\n        else:\n            collator = DataCollatorForTokenClassification\n\n        # ZO2 added ->\n        trainer = ZOTrainer(\n            model=self.model, \n            args=self.args,\n            train_dataset=train_dataset, \n            eval_dataset=eval_dataset,\n            processing_class=self.tokenizer,\n            data_collator=DataCollatorWithPaddingAndNesting(self.tokenizer, pad_to_multiple_of=8) if 
self.args.train_as_classification else collator(self.tokenizer, pad_to_multiple_of=8),\n        )\n        if self.args.save_on_interrupt:\n            trainer.add_callback(SIGUSR1Callback())\n\n        # Resume training from a last checkpoint\n        last_checkpoint = None\n        from transformers.trainer_utils import get_last_checkpoint\n        if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir:\n            last_checkpoint = get_last_checkpoint(self.args.output_dir)\n        if last_checkpoint is not None and self.args.resume_from_checkpoint is None:\n            logger.info(\n                f\"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change \"\n                \"the `--output_dir` or add `--overwrite_output_dir` to train from scratch.\"\n            )\n        if self.args.resume_from_checkpoint is not None:\n            last_checkpoint = self.args.resume_from_checkpoint\n\n        trainer.train(resume_from_checkpoint=last_checkpoint) \n\n        # Explicitly save the model\n        if self.args.save_model:\n            logger.warn(\"Save model..\")\n            trainer.save_model()\n        \n        # FSDP compatibility\n        self.model = trainer.model \n        \n        # Reset the forward function for evaluation\n        if self.args.only_train_option and not self.args.non_diff:\n        #     if type(self.model) == FSDP:\n        #         logger.info(\"This is an FSDP model now. 
Be careful when assigning back the original forward function\")\n        #         self.model._fsdp_wrapped_module.forward = self.model._fsdp_wrapped_module.original_forward\n        #     else:\n        #         self.model.forward = self.model.original_forward\n            # ZO2 added -> remove the custom loss functions for evaluation\n            self.model.zo_custom_train_loss_fn = None\n            self.model.zo_custom_eval_loss_fn = None\n\n\ndef result_file_tag(args):\n    \"\"\"\n    Get the result file tag\n    \"\"\"\n    save_model_name = args.model_name.split(\"/\")[-1]\n    sfc_tag = \"-sfc\" if args.sfc else \"\"\n    icl_sfc_tag = \"-icl_sfc\" if args.icl_sfc else \"\"\n    sample_eval_tag = \"-sampleeval%d\" % args.num_eval if args.num_eval is not None else \"\"\n    sample_train_tag = \"-ntrain%d\" % args.num_train if args.num_train > 0 else \"\"\n    sample_dev_tag = \"-ndev%d\" % args.num_dev if args.num_dev is not None else \"\"\n    customized_tag = f\"-{args.tag}\" if len(args.tag) > 0 else \"\"\n    return f\"{args.task_name}-{save_model_name}\" + sfc_tag + icl_sfc_tag + sample_eval_tag + sample_train_tag + sample_dev_tag + customized_tag\n\n\ndef main():\n    args = parse_args()\n\n    set_seed(args.seed)\n    task = get_task(args.task_name)\n    train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval, num_train_sets=args.num_train_sets, seed=args.train_set_seed)\n\n    # Initialize trainer and load model\n    framework = Framework(args, task)\n\n    if args.train_set_seed is not None or args.num_train_sets is not None:\n        # Eval samples share one (or multiple) training set(s)\n        for train_set_id, train_samples in enumerate(train_sets):\n            train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed\n\n            # Sample eval samples\n            if args.num_eval is not None:\n                eval_samples = 
task.sample_subset(data_split=\"valid\", seed=train_set_seed, num=args.num_eval)\n            else:\n                eval_samples = task.valid_samples\n\n            if args.trainer != \"none\":\n                if args.num_dev is not None:\n                    # Dev samples\n                    dev_samples = train_samples[-args.num_dev:] \n                    train_samples = train_samples[:-args.num_dev]\n                else:\n                    dev_samples = None\n\n                # Training\n                framework.train(train_samples, dev_samples if dev_samples is not None else eval_samples)\n\n                if not args.no_eval:\n                    metrics = framework.evaluate([], eval_samples) # No in-context learning if there is training\n                    if dev_samples is not None:\n                        dev_metrics = framework.evaluate([], dev_samples) \n                        for m in dev_metrics:\n                            metrics[\"dev_\" + m] = dev_metrics[m]\n            else:\n                assert args.num_dev is None\n                # Zero-shot / in-context learning\n                metrics = framework.evaluate(train_samples, eval_samples)\n\n            if not args.no_eval:\n                logger.info(\"===== Train set %d =====\" % train_set_seed)\n                logger.info(metrics)\n                if args.local_rank <= 0:\n                    write_metrics_to_file(metrics, \"result/\" +  result_file_tag(args) + f\"-trainset{train_set_id}.json\" if args.result_file is None else args.result_file)\n\n    else:\n        # For each eval sample, there is a training set. 
no training is allowed\n        # This is for in-context learning (ICL)\n        assert args.trainer == \"none\"\n        if args.num_eval is not None:\n            eval_samples = task.sample_subset(data_split=\"valid\", seed=0, num=args.num_eval)\n        else:\n            eval_samples = task.valid_samples\n\n        metrics = framework.evaluate(train_sets, eval_samples, one_train_set_per_eval_sample=True)\n        logger.info(metrics)\n        if args.local_rank <= 0:\n            write_metrics_to_file(metrics, \"result/\" + result_file_tag(args) + \"-onetrainpereval.json\" if args.result_file is None else args.result_file)\n\nif __name__ == \"__main__\": \n    main()\n"
  },
  {
    "path": "example/mezo_runner/tasks.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nCopied https://github.com/princeton-nlp/MeZO/blob/main/large_models/tasks.py\n\"\"\"\n\nfrom templates import *\nfrom utils import temp_seed\nimport json\nimport os\nfrom datasets import load_dataset\nfrom dataclasses import dataclass\nfrom typing import List, Union\nimport string\nimport random\nimport datasets\nimport sys\nimport numpy as np\nimport logging\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef get_task(task_name):\n    aa = task_name.split(\"__\")\n    if len(aa) == 2:\n        task_group, subtask = aa\n    else:\n        task_group = aa[0]\n        subtask = None\n    class_ = getattr(sys.modules[__name__], f\"{task_group}Dataset\")\n    instance = class_(subtask)\n    return instance\n\n\n@dataclass\nclass Sample:\n    id: int = None\n    data: dict = None\n    correct_candidate: Union[str, List[str]] = None\n    candidates: List[str] = None\n\n\nclass Dataset:\n    mixed_set = False\n    train_sep = \"\\n\\n\"\n    generation = False # whether this is a generation task\n\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.subtask = subtask\n    \n    def get_task_name(self):\n        return self.subtask\n        \n    def load_dataset():\n        raise NotImplementedError\n    \n    def get_template(self, template_version=0):\n       templates = {0: Template}\n       return templates[template_version]\n   \n    def build_sample(self, example):\n        return \n     \n    def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None, num_train_sets=None, seed=None):\n        if seed is not None:\n            # one train/demo set using the designated seed\n            seeds = [seed]\n        elif num_train_sets is not None:\n            # num_train_sets train/demo sets\n            seeds = list(range(num_train_sets))\n        else: \n            # one train/demo set per evaluation 
sample\n            assert num_dev is None # not supported\n            len_valid_samples = len(self.samples[\"valid\"]) if num_eval is None else num_eval\n            with temp_seed(0):\n                seeds = np.random.randint(0, 10000, len_valid_samples)\n\n        train_samples = [] \n        for i, set_seed in enumerate(seeds):\n            if self.mixed_set:\n                raise NotImplementedError\n                train_samples.append(self.sample_subset(data_split=\"valid\", seed=set_seed, num=num_train, exclude=i))\n            else:\n                if num_dev is not None:\n                    train_samples.append(self.sample_subset(data_split=\"train\", seed=set_seed, num=num_train+num_dev)) # dev set is included at the end of train set\n                    if num_train + num_dev > len(self.samples[\"train\"]):\n                        logger.warn(\"num_train + num_dev > available training examples\")\n                else:\n                    train_samples.append(self.sample_subset(data_split=\"train\", seed=set_seed, num=num_train))\n                if num_dev is not None:\n                    logger.info(f\"Sample train set {len(train_samples[-1])}/{len(self.samples['train'])}\")\n                    logger.info(f\"... 
including dev set {num_dev} samples\")\n        return train_samples\n\n    def sample_subset(self, data_split=\"train\", seed=0, num=100, exclude=None):\n        with temp_seed(seed):\n            samples = self.samples[data_split] \n            lens = len(samples)\n            index = np.random.permutation(lens).tolist()[:num if exclude is None else num+1]\n            if exclude is not None and exclude in index:\n                index.remove(exclude)\n            else:\n                index = index[:num]\n            return [samples[i] for i in index]\n    \n    @property\n    def valid_samples(self):\n        return self.samples[\"valid\"]\n\n\nclass SST2Dataset(Dataset):\n    train_sep = \"\\n\\n\"\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n        \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset('glue', 'sst2')\n        train_d = d[\"train\"]\n        validation_d = d[\"validation\"]\n        \n        train_samples = [self.build_sample(example) for example in train_d]\n        valid_samples = [self.build_sample(example) for example in validation_d]\n        \n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    # for generative tasks, candidates are []\n    def build_sample(self, example):\n        label = int(example[\"label\"])\n        return Sample(id=example[\"idx\"], data=example, correct_candidate=label, candidates=[0, 1])\n        \n    def get_template(self, template_version=0):\n        return {0: SST2Template}[template_version]()\n        \n    \nclass CopaDataset(Dataset):\n    train_sep = \"\\n\\n\"\n    mixed_set = False\n\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n        \n    def load_dataset(self, path, **kwargs):\n        train_examples = load_dataset('super_glue', \"copa\")[\"train\"]\n        valid_examples = load_dataset('super_glue', 
\"copa\")[\"validation\"]\n    \n        train_samples = [self.build_sample(example) for example in train_examples]\n        valid_samples = [self.build_sample(example) for example in valid_examples]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    # for generative tasks, candidates are []\n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                id=example[\"idx\"],\n                data=example,\n                candidates=[example[\"choice1\"], example[\"choice2\"]],\n                correct_candidate=example[f\"choice{example['label'] + 1}\"],\n            )\n        \n        return sample\n        \n    def get_template(self, template_version=0):\n        return {0: CopaTemplate}[template_version]()\n\n\nclass BoolQDataset(Dataset):\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"boolq\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[\"Yes\", \"No\"],\n                correct_candidate=\"Yes\" if example[\"answer\"] else \"No\",\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=2):\n        return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]()\n\n\nclass MultiRCDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = 
load_dataset(\"super_glue\", \"multirc\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[0, 1],\n                correct_candidate=example['label']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: MultiRCTemplate}[template_version]()\n\n\nclass CBDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"super_glue\", \"cb\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[0, 1, 2],\n                correct_candidate=example['label']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: CBTemplate}[template_version]()\n\n\nclass WICDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"super_glue\", \"wic\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = 
[self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[0, 1],\n                correct_candidate=example['label']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: WICTemplate}[template_version]()\n\n\nclass WSCDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"super_glue\", \"wsc.fixed\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[0, 1],\n                correct_candidate=example['label']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: WSCTemplate}[template_version]()\n\n\nclass ReCoRDDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"super_glue\", \"record\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = 
{\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=example['entities'],\n                correct_candidate=example['answers']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: ReCoRDTemplateGPT3}[template_version]()\n\n\nclass RTEDataset(Dataset):\n    \n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset(subtask, **kwargs)\n    \n    def load_dataset(self, path, **kwargs):\n        d = load_dataset(\"super_glue\", \"rte\")\n        train_set = d[\"train\"]\n        valid_set = d[\"validation\"]\n\n        train_samples = [self.build_sample(example) for example in train_set]\n        valid_samples = [self.build_sample(example) for example in valid_set]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    def build_sample(self, example):\n        sample = \\\n            Sample(\n                data=example,\n                candidates=[0, 1],\n                correct_candidate=example['label']\n            )\n        \n        return sample\n    \n    def get_template(self, template_version=0):\n        return {0: RTETemplate}[template_version]()\n\n \nclass SQuADDataset(Dataset):\n    metric_name = \"f1\"\n    generation = True\n\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset()\n        \n    def load_dataset(self):\n        dataset = load_dataset(\"squad\")\n        train_examples = dataset[\"train\"]\n        valid_examples = dataset[\"validation\"]\n\n        train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]\n        valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    
# for generative tasks, candidates are []\n    def build_sample(self, example, idx):\n        answers = example['answers']['text']\n        assert len(answers) > 0\n        return Sample(\n            id=idx,\n            data={\n                \"title\": example['title'],\n                \"context\": example['context'],\n                \"question\": example['question'],\n                \"answers\": answers\n            },\n            candidates=None,\n            correct_candidate=answers\n        )\n        \n    def get_template(self, template_version=0):\n        return {0: SQuADv2Template}[template_version]()\n\n\nclass DROPDataset(Dataset):\n    metric_name = \"f1\"\n    generation = True\n\n    def __init__(self, subtask=None, **kwargs) -> None:\n        self.load_dataset()\n        \n    def load_dataset(self):\n        dataset = load_dataset(\"drop\")\n        train_examples = dataset[\"train\"]\n        valid_examples = dataset[\"validation\"]\n\n        train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)]\n        valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)]\n        self.samples = {\"train\": train_samples, \"valid\": valid_samples}\n    \n    # for generative tasks, candidates are []\n    def build_sample(self, example, idx):\n        answers = example['answers_spans']['spans']\n        assert len(answers) > 0\n        return Sample(\n            id=idx,\n            data={\n                \"context\": example['passage'],\n                \"question\": example['question'],\n                \"answers\": answers\n            },\n            candidates=None,\n            correct_candidate=answers\n        )\n        \n    def get_template(self, template_version=0):\n        return {0: DROPTemplate}[template_version]()\n"
  },
  {
    "path": "example/mezo_runner/templates.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nCopied https://github.com/princeton-nlp/MeZO/blob/main/large_models/templates.py\n\"\"\"\n\nclass Template:\n    def encode(self, sample):\n        \"\"\"\n        Return prompted version of the example (without the answer/candidate)\n        \"\"\"\n        raise NotImplementedError\n    \n    def verbalize(self, sample, candidate):\n        \"\"\"\n        Return the prompted version of the example (with the answer/candidate)\n        \"\"\"\n        return candidate\n    \n    def encode_sfc(self, sample):\n        \"\"\"\n        Same as encode, but for SFC (calibration) -- this usually means the input is not included\n        \"\"\"\n        return \"<mask>\"\n    \n    def verbalize_sfc(self, sample, candidate):\n        \"\"\"\n        Same as verbalize, but for SFC (calibration) -- this usually means the input is not included\n        \"\"\"\n        return candidate\n\n\nclass SST2Template(Template):\n    verbalizer = {0: \"terrible\", 1: \"great\"}\n    def encode(self, sample):\n        text = sample.data[\"sentence\"].strip()\n        return f\"{text} It was\"\n\n    def verbalize(self, sample, candidate):\n        text = sample.data[\"sentence\"].strip()\n        return f\"{text} It was {self.verbalizer[candidate]}\"\n    \n    def encode_sfc(self, sample):\n        return f\" It was\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\" It was {self.verbalizer[candidate]}\"\n\n\nclass CopaTemplate(Template):\n    capitalization: str = \"correct\"\n    effect_conj: str = \" so \"\n    cause_conj: str = \" because \"\n\n    def get_conjucture(self, sample):\n        if sample.data[\"question\"] == \"effect\":\n            conjunction = self.effect_conj\n        elif sample.data[\"question\"] == \"cause\":\n            conjunction = self.cause_conj\n        else:\n            raise NotImplementedError\n        return 
conjunction\n    \n    def get_prompt(self, sample):\n        premise = sample.data[\"premise\"].rstrip()\n        if premise.endswith(\".\"):  # TODO Add other scripts with different punctuation\n            premise = premise[:-1]\n        conjunction = self.get_conjucture(sample)\n        prompt = premise + conjunction\n        if self.capitalization == \"upper\":\n            prompt = prompt.upper()\n        elif self.capitalization == \"lower\":\n            prompt = prompt.lower()\n        return prompt\n\n    def encode(self, sample):\n        prompt = self.get_prompt(sample)\n        return prompt \n\n    def capitalize(self, c):\n        if self.capitalization == \"correct\":\n            words = c.split(\" \")\n            if words[0] != \"I\":\n                words[0] = words[0].lower()\n            return \" \".join(words)\n        elif self.capitalization == \"bug\":\n            return c\n        elif self.capitalization == \"upper\":\n            return c.upper()\n        elif self.capitalization == \"lower\":\n            return c.lower()\n        else:\n            raise NotImplementedError\n            \n    def verbalize(self, sample, candidate):\n        prompt = self.get_prompt(sample)\n        return prompt + self.capitalize(candidate)\n    \n    def encode_sfc(self, sample):\n        conjunction = self.get_conjucture(sample)\n        return conjunction.strip() \n\n    def verbalize_sfc(self, sample, candidate):\n        conjunction = self.get_conjucture(sample)\n        sfc_prompt = conjunction.strip() + \" \" + self.capitalize(candidate)\n        return sfc_prompt\n        \n    \nclass BoolQTemplate(Template):\n    def encode(self, sample):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question}\"\n\n    def verbalize(self, sample, 
candidate):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question} {candidate}\"\n    \n    def encode_sfc(self, sample):\n        return \"\"\n    \n    def verbalize_sfc(self, sample, candidate):\n        return candidate\n\n\nclass BoolQTemplateV2(Template):\n    def encode(self, sample):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question}\\\\n\\\\n\"\n\n    def verbalize(self, sample, candidate):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question}\\\\n\\\\n{candidate}\"\n    \n    def encode_sfc(self, sample):\n        return \"\"\n    \n    def verbalize_sfc(self, sample, candidate):\n        return candidate\n\n\nclass BoolQTemplateV3(Template):\n    def encode(self, sample):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question}\\n\"\n\n    def verbalize(self, sample, candidate):\n        passage = sample.data[\"passage\"]\n        question = sample.data[\"question\"]\n        if not question.endswith(\"?\"):\n            question = question + \"?\"\n        question = question[0].upper() + question[1:]\n        return f\"{passage} {question}\\n{candidate}\"\n    \n    def encode_sfc(self, sample):\n        return 
\"\"\n    \n    def verbalize_sfc(self, sample, candidate):\n        return candidate\n    \n\nclass MultiRCTemplate(Template):\n    # From PromptSource 1\n    verbalizer = {0: \"No\", 1: \"Yes\"}\n\n    def encode(self, sample):\n        paragraph = sample.data[\"paragraph\"]\n        question = sample.data[\"question\"]\n        answer = sample.data[\"answer\"]\n        return f\"{paragraph}\\nQuestion: {question}\\nI found this answer \\\"{answer}\\\". Is that correct? Yes or No?\\n\"\n\n    def verbalize(self, sample, candidate):\n        paragraph = sample.data[\"paragraph\"]\n        question = sample.data[\"question\"]\n        answer = sample.data[\"answer\"]\n        return f\"{paragraph}\\nQuestion: {question}\\nI found this answer \\\"{answer}\\\". Is that correct? Yes or No?\\n{self.verbalizer[candidate]}\"\n\n    def encode_sfc(self, sample):\n        return f\"\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"{self.verbalizer[candidate]}\"\n\n    \nclass CBTemplate(Template):\n    # From PromptSource 1\n    verbalizer = {0: \"Yes\", 1: \"No\", 2: \"Maybe\"}\n\n    def encode(self, sample):\n        premise = sample.data[\"premise\"]\n        hypothesis = sample.data[\"hypothesis\"]\n        return f\"Suppose {premise} Can we infer that \\\"{hypothesis}\\\"? Yes, No, or Maybe?\\n\"\n\n    def verbalize(self, sample, candidate):\n        premise = sample.data[\"premise\"]\n        hypothesis = sample.data[\"hypothesis\"]\n        return f\"Suppose {premise} Can we infer that \\\"{hypothesis}\\\"? 
Yes, No, or Maybe?\\n{self.verbalizer[candidate]}\"\n\n    def encode_sfc(self, sample):\n        return f\"\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"{self.verbalizer[candidate]}\"\n\n\nclass WICTemplate(Template):\n    # From PromptSource 1\n    verbalizer = {0: \"No\", 1: \"Yes\"}\n\n    def encode(self, sample):\n        sent1 = sample.data[\"sentence1\"]\n        sent2 = sample.data[\"sentence2\"]\n        word = sample.data[\"word\"]\n        return f\"Does the word \\\"{word}\\\" have the same meaning in these two sentences? Yes, No?\\n{sent1}\\n{sent2}\\n\"\n\n    def verbalize(self, sample, candidate):\n        sent1 = sample.data[\"sentence1\"]\n        sent2 = sample.data[\"sentence2\"]\n        word = sample.data[\"word\"]\n        return f\"Does the word \\\"{word}\\\" have the same meaning in these two sentences? Yes, No?\\n{sent1}\\n{sent2}\\n{self.verbalizer[candidate]}\"\n\n    def encode_sfc(self, sample):\n        return f\"\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"{self.verbalizer[candidate]}\"\n\n\nclass WSCTemplate(Template):\n    # From PromptSource 1\n    verbalizer = {0: \"No\", 1: \"Yes\"}\n\n    def encode(self, sample):\n        text = sample.data['text']\n        span1 = sample.data['span1_text']\n        span2 = sample.data['span2_text']\n        return f\"{text}\\nIn the previous sentence, does the pronoun \\\"{span2.lower()}\\\" refer to {span1}? Yes or No?\\n\"\n\n    def verbalize(self, sample, candidate):\n        text = sample.data['text']\n        span1 = sample.data['span1_text']\n        span2 = sample.data['span2_text']\n        return f\"{text}\\nIn the previous sentence, does the pronoun \\\"{span2.lower()}\\\" refer to {span1}? 
Yes or No?\\n{self.verbalizer[candidate]}\"\n\n    def encode_sfc(self, sample):\n        return f\"\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"{self.verbalizer[candidate]}\"\n\n\nclass ReCoRDTemplate(Template):\n    # From PromptSource 1 but modified\n\n    def encode(self, sample):\n        passage = sample.data['passage']\n        query = sample.data['query']\n        return f\"{passage}\\n{query}\\nQuestion: what is the \\\"@placeholder\\\"\\nAnswer:\"\n\n    def verbalize(self, sample, candidate):\n        passage = sample.data['passage']\n        query = sample.data['query']\n        return f\"{passage}\\n{query}\\nQuestion: what is the \\\"@placeholder\\\"\\nAnswer: {candidate}\"\n\n    def encode_sfc(self, sample):\n        return f\"Answer:\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"Answer: {candidate}\"\n\n\nclass ReCoRDTemplateGPT3(Template):\n    # From PromptSource 1 but modified\n\n    def encode(self, sample):\n        passage = sample.data['passage'].replace(\"@highlight\\n\", \"- \")\n        return f\"{passage}\\n-\"\n\n    def verbalize(self, sample, candidate):\n        passage = sample.data['passage'].replace(\"@highlight\\n\", \"- \")\n        query = sample.data['query'].replace(\"@placeholder\", candidate[0] if isinstance(candidate, list) else candidate)\n        return f\"{passage}\\n- {query}\"\n\n        # passage = sample.data['passage']\n        # query = sample.data['query']\n        # return f\"{passage}\\n{query}\\nQuestion: what is the \\\"@placeholder\\\"\\nAnswer: {candidate}\"\n\n    def encode_sfc(self, sample):\n        return f\"-\"\n\n    def verbalize_sfc(self, sample, candidate):\n        query = sample.data['query'].replace(\"@placeholder\", candidate[0] if isinstance(candidate, list) else candidate)\n        return f\"- {query}\"\n\n\nclass RTETemplate(Template):\n    # From PromptSource 1\n    verbalizer={0: \"Yes\", 1: \"No\"}\n\n    def encode(self, sample):\n   
     premise = sample.data['premise']\n        hypothesis = sample.data['hypothesis']\n        return f\"{premise}\\nDoes this mean that \\\"{hypothesis}\\\" is true? Yes or No?\\n\"\n\n    def verbalize(self, sample, candidate):\n        premise = sample.data['premise']\n        hypothesis = sample.data['hypothesis']\n        return f\"{premise}\\nDoes this mean that \\\"{hypothesis}\\\" is true? Yes or No?\\n{self.verbalizer[candidate]}\"\n\n    def encode_sfc(self, sample):\n        return f\"\"\n\n    def verbalize_sfc(self, sample, candidate):\n        return f\"{self.verbalizer[candidate]}\"\n\n\nclass SQuADv2Template(Template):\n\n    def encode(self, sample):\n        question = sample.data['question'].strip()\n        title = sample.data['title']\n        context = sample.data['context']\n        answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one\n\n        return f\"Title: {title}\\nContext: {context}\\nQuestion: {question}\\nAnswer:\"\n\n    def verbalize(self, sample, candidate):\n        question = sample.data['question'].strip()\n        title = sample.data['title']\n        context = sample.data['context']\n        answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one\n\n        return f\"Title: {title}\\nContext: {context}\\nQuestion: {question}\\nAnswer: {answer}\\n\"\n\n    \n    def encode_sfc(self, sample):\n        raise NotImplementedError\n\n    def verbalize_sfc(self, sample, candidate):\n        raise NotImplementedError\n\n\nclass DROPTemplate(Template):\n\n    def encode(self, sample):\n        question = sample.data['question'].strip()\n        # title = sample.data['title']\n        context = sample.data['context']\n        answer = sample.data['answers'][0] # there are multiple answers. 
for the prompt we only take the first one\n\n        return f\"Passage: {context}\\nQuestion: {question}\\nAnswer:\"\n\n    def verbalize(self, sample, candidate):\n        question = sample.data['question'].strip()\n        # title = sample.data['title']\n        context = sample.data['context']\n        answer = sample.data['answers'][0] # there are multiple answers. for the prompt we only take the first one\n\n        return f\"Passage: {context}\\nQuestion: {question}\\nAnswer: {answer}\\n\"\n\n    \n    def encode_sfc(self, sample):\n        raise NotImplementedError\n\n    def verbalize_sfc(self, sample, candidate):\n        raise NotImplementedError\n"
  },
  {
    "path": "example/mezo_runner/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/princeton-nlp/MeZO/blob/main/large_models/utils.py\n\"\"\"\n\nimport json\nimport os\nimport contextlib\nfrom typing import Optional, Union\nimport numpy as np\nfrom dataclasses import dataclass, is_dataclass, asdict\nimport logging\nimport time\nfrom torch.nn import CrossEntropyLoss\nimport torch.nn.functional as F\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nimport torch\nfrom transformers.utils import PaddingStrategy\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.data.data_collator import DataCollatorMixin\nimport transformers\nfrom typing import Optional, Union, List, Dict, Any\nimport signal\nfrom subprocess import call\nfrom collections.abc import Mapping\nfrom typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union\nInputDataClass = NewType(\"InputDataClass\", Any)\nfrom dataclasses import dataclass\nfrom transformers.tokenization_utils_base import PreTrainedTokenizerBase\n\nlogger = logging.getLogger(__name__)\n\n\n\ndef custom_loss_fn_with_option_len(self, input_ids, logits, labels, option_len=None, num_options=None):\n    \"\"\"\n    Modified from below 'forward_wrap_with_option_len'.\n    \"\"\"\n    loss = None\n    # Shift so that tokens < n predict n\n    shift_logits = logits[..., :-1, :].contiguous()\n    # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs\n    shift_labels = torch.clone(input_ids)[..., 1:].contiguous()\n    shift_labels[shift_labels == self.config.pad_token_id] = -100\n\n    # Apply option len (do not calculate loss on the non-option part)\n    for _i, _len in enumerate(option_len):\n        shift_labels[_i, :-_len] = -100\n\n    # Calculate the loss\n    loss_fct = CrossEntropyLoss(ignore_index=-100)\n    if num_options is not None: \n        # Train as a classification tasks\n      
  log_probs = F.log_softmax(shift_logits, dim=-1)\n        mask = shift_labels != -100 # Option part\n        shift_labels[~mask] = 0 # So that it doesn't mess up with indexing\n\n        selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)\n        selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)\n\n        if any([x != num_options[0] for x in num_options]):\n            # Multi choice tasks with different number of options\n            loss = 0\n            start_id = 0\n            count = 0\n            while start_id < len(num_options):\n                end_id = start_id + num_options[start_id]\n                _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options)\n                _labels = labels[start_id:end_id][0].unsqueeze(0) # (1)\n                loss = loss_fct(_logits, _labels) + loss\n                count += 1\n                start_id = end_id\n            loss = loss / count\n        else:\n            num_options = num_options[0]\n            selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)\n            labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one\n            loss = loss_fct(selected_log_probs, labels)\n    else:\n        loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))\n\n    return loss\n\n\ndef forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs):\n    \"\"\"\n    This is to replace the original forward function of Transformer models to enable:\n    (1) Partial target sequence: loss will only be calculated on part of the sequence\n    (2) Classification-style training: a classification loss (CE) will be calculated over several options\n    Input:\n    - input_ids, labels: same as the original forward function\n    - 
option_len: a list of int indicating the option lengths, and loss will be calculated only on the\n      last option_len tokens \n    - num_options: a list of int indicating the number of options for each example (this will be #label\n      words for classification tasks and #choices for multiple choice tasks), and a classification loss\n      will be calculated.\n    \"\"\"\n    outputs = self.original_forward(input_ids=input_ids, **kwargs)\n    if labels is None:\n        return outputs\n    logits = outputs.logits\n\n    loss = None\n    # Shift so that tokens < n predict n\n    shift_logits = logits[..., :-1, :].contiguous()\n    # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs\n    shift_labels = torch.clone(input_ids)[..., 1:].contiguous()\n    shift_labels[shift_labels == self.config.pad_token_id] = -100\n\n    # Apply option len (do not calculate loss on the non-option part)\n    for _i, _len in enumerate(option_len):\n        shift_labels[_i, :-_len] = -100\n\n    # Calculate the loss\n    loss_fct = CrossEntropyLoss(ignore_index=-100)\n    if num_options is not None: \n        # Train as a classification tasks\n        log_probs = F.log_softmax(shift_logits, dim=-1)\n        mask = shift_labels != -100 # Option part\n        shift_labels[~mask] = 0 # So that it doesn't mess up with indexing\n\n        selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)\n        selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options)\n\n        if any([x != num_options[0] for x in num_options]):\n            # Multi choice tasks with different number of options\n            loss = 0\n            start_id = 0\n            count = 0\n            while start_id < len(num_options):\n                end_id = start_id + num_options[start_id]\n                _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # 
(1, num_options)\n                _labels = labels[start_id:end_id][0].unsqueeze(0) # (1)\n                loss = loss_fct(_logits, _labels) + loss\n                count += 1\n                start_id = end_id\n            loss = loss / count\n        else:\n            num_options = num_options[0]\n            selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)\n            labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one\n            loss = loss_fct(selected_log_probs, labels)\n    else:\n        loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))\n\n    if not return_dict:\n        output = (logits,) + outputs[1:]\n        return (loss,) + output if loss is not None else output\n\n    return CausalLMOutputWithPast(\n        loss=loss,\n        logits=logits,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n\n\ndef encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, max_new_tokens=None):\n    \"\"\"\n    Encode prompts for eval_sample\n    Input: \n    - task, template: task and template class\n    - train_samples, eval_sample: demonstrations and the actual sample\n    - tokenizer, max_length: tokenizer and max length\n    - sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315)\n    - icl_sfc: generate prompts for ICL version calibration\n    - generation: whether it is an generation task\n    - generation_with_gold: whether to include the generation-task gold answers (for training)\n    - max_new_tokens: max number of new tokens to generate so that we can save enough space \n      (only for generation tasks)\n    Output:\n    - encodings: a list of N lists of tokens. 
N is the number of options for classification/multiple-choice.\n    - option_lens: a list of N integers indicating the number of option tokens.\n    \"\"\"\n\n    # Demonstrations for ICL\n    train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]\n    train_prompts = task.train_sep.join(train_prompts).strip()\n    \n    # sfc or icl_sfc indicates that this example is used for calibration\n    if sfc or icl_sfc:\n        encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc\n    else: \n        encode_fn = template.encode; verbalize_fn = template.verbalize \n            \n    unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')\n    if not generation:\n        # We generate one prompt for each candidate (different classes in classification)\n        # or different choices in multiple-choice tasks\n        verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]\n        unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))\n        option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]\n\n        if sfc:\n            # Without demonstrations\n            final_prompts = verbalized_eval_prompts \n        else:\n            # With demonstrations\n            final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] \n    else:\n        assert not sfc and not icl_sfc, \"Generation tasks do not support SFC\"\n        if generation_with_gold:\n            verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]\n            unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))\n            option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for 
verbalized_eval_prompt in verbalized_eval_prompts]\n            final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] \n        else:\n            option_lens = [0]\n            final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]\n\n    # Tokenize \n    encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]\n\n    # Truncate (left truncate as demonstrations are less important)\n    if generation and max_new_tokens is not None:\n        max_length = max_length - max_new_tokens\n\n    if any([len(encoding) > max_length for encoding in encodings]):\n        logger.warn(\"Exceed max length\")\n    if tokenizer.add_bos_token:\n        encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings]  \n    else:\n        encodings = [encoding[-max_length:] for encoding in encodings]  \n   \n    return encodings, option_lens\n \n\n\n@dataclass\nclass ICLCollator:\n    \"\"\"\n    Collator for ICL\n    \"\"\"\n    tokenizer: PreTrainedTokenizerBase\n\n    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:\n        if not isinstance(features[0], Mapping):\n            features = [vars(f) for f in features]\n        first = features[0]\n        batch = {}\n        \n        pad_id = self.tokenizer.pad_token_id\n\n        pad_ids = {\"input_ids\": pad_id, \"attention_mask\": 0, \"sfc_input_ids\": pad_id, \"sfc_attention_mask\": 0, \"labels\": pad_id}\n        for key in first:\n            pp = pad_ids[key]\n            lens = [len(f[key]) for f in features]\n            max_len = max(lens)\n            feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), \"constant\", constant_values=(0, pp)) for i, f in enumerate(features)])\n            padded_feature = torch.from_numpy(feature).long()\n            batch[key] = padded_feature\n            \n        return batch\n\n\n@dataclass\nclass 
DataCollatorWithPaddingAndNesting:\n    \"\"\"\n    Collator for training\n    \"\"\"\n\n    tokenizer: PreTrainedTokenizerBase\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    return_tensors: str = \"pt\"\n\n    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:\n        features = [ff for f in features for ff in f]\n        batch = self.tokenizer.pad(\n            features,\n            padding=self.padding,\n            max_length=self.max_length,\n            pad_to_multiple_of=self.pad_to_multiple_of,\n            return_tensors=self.return_tensors,\n        )\n        if \"label\" in batch:\n            batch[\"labels\"] = batch[\"label\"]\n            del batch[\"label\"]\n        if \"label_ids\" in batch:\n            batch[\"labels\"] = batch[\"label_ids\"]\n            del batch[\"label_ids\"]\n        return batch\n\n\n@dataclass\nclass NondiffCollator(DataCollatorMixin):\n    \"\"\"\n    Collator for non-differentiable objectives\n    \"\"\"\n    tokenizer: PreTrainedTokenizerBase\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    label_pad_token_id: int = -100\n    return_tensors: str = \"pt\"\n\n    def torch_call(self, features):\n        import torch\n\n        label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None\n\n        no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != \"gold\"} for feature in features]\n\n        batch = self.tokenizer.pad(\n            no_labels_features,\n            padding=self.padding,\n            max_length=self.max_length,\n            pad_to_multiple_of=self.pad_to_multiple_of,\n            return_tensors=\"pt\",\n        )\n\n        if labels is 
None:\n            return batch\n\n        sequence_length = batch[\"input_ids\"].shape[1]\n        padding_side = self.tokenizer.padding_side\n\n        def to_list(tensor_or_iterable):\n            if isinstance(tensor_or_iterable, torch.Tensor):\n                return tensor_or_iterable.tolist()\n            return list(tensor_or_iterable)\n\n        if padding_side == \"right\":\n            batch[label_name] = [\n                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels\n            ]\n        else:\n            batch[label_name] = [\n                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels\n            ]\n\n        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)\n        if \"gold\" in features[0]:\n            batch[\"gold\"] = [feature[\"gold\"] for feature in features]\n        \n        return batch\n        \n\nclass SIGUSR1Callback(transformers.TrainerCallback):\n    \"\"\"\n    This callback is used to save the model when a SIGUSR1 signal is received\n    (SLURM stop signal or a keyboard interruption signal).\n    \"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n        self.signal_received = False\n        signal.signal(signal.SIGUSR1, self.handle_signal)\n        signal.signal(signal.SIGINT, self.handle_signal)\n        logger.warn(\"Handler registered\")\n\n    def handle_signal(self, signum, frame):\n        self.signal_received = True\n        logger.warn(\"Signal received\")\n\n    def on_step_end(self, args, state, control, **kwargs):\n        if self.signal_received:\n            control.should_save = True\n            control.should_training_stop = True\n\n    def on_train_end(self, args, state, control, **kwargs):\n        if self.signal_received:\n            exit(0)\n\n\n@dataclass\nclass Prediction:\n    correct_candidate: Union[int, str]\n    predicted_candidate: Union[int, 
str]\n\n\n@contextlib.contextmanager\ndef count_time(name):\n    logger.info(\"%s...\" % name)\n    start_time = time.time()\n    try:\n        yield\n    finally:\n        logger.info(\"Done with %.2fs\" % (time.time() - start_time))\n\n\n@contextlib.contextmanager\ndef temp_seed(seed):\n    state = np.random.get_state()\n    np.random.seed(seed)\n    try:\n        yield\n    finally:\n        np.random.set_state(state)\n\n\nclass EnhancedJSONEncoder(json.JSONEncoder):\n    def default(self, o):\n        if is_dataclass(o):\n            return asdict(o)\n        return super().default(o)\n\n\ndef write_predictions_to_file(final_preds, output):\n    with open(output, \"w\") as f:\n        for pred in final_preds:\n            f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + \"\\n\")\n\n\ndef write_metrics_to_file(metrics, output):\n    json.dump(metrics, open(output, \"w\"), cls=EnhancedJSONEncoder, indent=4)"
  },
  {
    "path": "requirements.txt",
    "content": "brotli==1.0.9\ncertifi==2024.7.4\ncharset-normalizer==3.3.2\nfilelock==3.13.1\nidna==3.7\nJinja2==3.1.4\nMarkupSafe==2.1.3\nnumpy==1.26.4\nPillow==10.4.0\nPySocks==1.7.1\nPyYAML==6.0.1\nrequests==2.32.3\nrich==14.0.0\nsetuptools==72.1.0\nurllib3==2.2.2\nwheel==0.43.0\naccelerate==1.6.0\ndatasets==3.5.1\naiohttp==3.10.3\naiosignal==1.3.1\nattrs==24.2.0\ndill==0.3.8\nfrozenlist==1.4.1\nfsspec==2024.5.0\nhuggingface-hub==0.24.5\njoblib==1.4.2\nmultidict==6.0.5\nmultiprocess==0.70.16\nopt-einsum==3.3.0\npackaging==24.1\npandas==2.2.2\npsutil==6.0.0\npyarrow==17.0.0\npyarrow-hotfix==0.6\npython-dateutil==2.9.0.post0\npytz==2024.1\nregex==2024.7.24\nscikit-learn==1.5.1\nscipy==1.14.0\nsix==1.16.0\nthreadpoolctl==3.5.0\ntokenizers==0.21.1\ntqdm==4.66.5\ntransformers==4.51.3\ntzdata==2024.1\nxxhash==3.4.1\nyarl==1.9.4\nnvidia-ml-py==12.570.86\ntrl==0.17.0\nsafetensors==0.5.2\n"
  },
  {
    "path": "script/add-copyright.py",
    "content": "import os\nimport datetime\nimport logging\n\ncurrent_year = datetime.datetime.now().year\nowner = \"liangyuwang\"\nlogging.basicConfig(filename='license_addition_errors.log', level=logging.ERROR)\n\ndef add_license_header(file_path, comment_style):\n    try:\n        with open(file_path, 'r+', encoding='utf-8') as file:\n            content = file.read()\n            license_snippet = \"Licensed under the Apache License, Version 2.0\"\n            if license_snippet not in content:\n                header = \"# Copyright Notice\\n\"\n                if comment_style == \"block\":\n                    header = f\"/* Copyright (c) {current_year} {owner}\\n * Licensed under the Apache License, Version 2.0\\n */\\n\\n\"\n                elif comment_style == \"line\":\n                    header = f\"# Copyright (c) {current_year} {owner}\\n# Licensed under the Apache License, Version 2.0\\n\\n\"\n                file.seek(0, 0)\n                file.write(header + content)\n    except FileNotFoundError:\n        logging.error(f\"File not found: {file_path}\")\n\nfile_map = {\n    '.cpp': 'block',\n    '.h': 'block',\n    '.cu': 'block',\n    '.py': 'line',\n    '.cmake': 'line'\n}\n\nfor root, dirs, files in os.walk(\".\"):\n    for file in files:\n        ext = os.path.splitext(file)[1]\n        if ext in file_map:\n            add_license_header(os.path.join(root, file), file_map[ext])\n        elif 'CMakeLists.txt' in file:\n            add_license_header(os.path.join(root, file), 'line')\n"
  },
  {
    "path": "script/clear-pycache.sh",
    "content": "find . | grep -E \"(/__pycache__$|\\.pyc$|\\.pyo$)\" | xargs rm -rf"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nwith open('requirements.txt') as f:\n    requirements = f.read().splitlines()\n\nsetup(\n    name='zo2',\n    version='0.1.1',\n    author='liangyuwang',\n    author_email='liangyu.wang@kaust.edu.sa',\n    description='ZO2 (Zeroth-Order Offloading), a framework for full parameter fine-tuning 175B LLMs with 18GB GPU memory',\n    long_description=open('README.md').read(),\n    long_description_content_type='text/markdown',\n    packages=find_packages(),\n    install_requires=requirements,  # List of dependencies, read from requirements.txt\n    classifiers=[\n        'Programming Language :: Python :: 3.11',\n        'Programming Language :: Python :: 3.12',\n        'License :: OSI Approved :: Apache Software License',\n        'Operating System :: OS Independent',\n    ],\n    python_requires='>=3.11',\n    include_package_data=True,\n    zip_safe=False\n)\n"
  },
  {
    "path": "test/README.md",
    "content": "# Test\n\n- **Important:** Fine-tuning the **OPT-175B** model requires a system with at least 18 GB of GPU memory and 600 GB of CPU memory.\n\n## Example: MeZO-SGD on OPT Models\n\n```shell\n# compare memory\nbash test/mezo_sgd/hf_opt/test_memory_train.sh\n```\n\n```shell\n# compare throughput\nbash test/mezo_sgd/hf_opt/test_speed_train.sh\n```\n\n```shell\n# compare accuracy\nbash test/mezo_sgd/hf_opt/test_acc_train.sh\n```\n\n## Supported Tests\n\nWork in progress."
  },
  {
    "path": "test/mezo_sgd/hf_gpt/trainer.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "test/mezo_sgd/hf_llama/trainer.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "test/mezo_sgd/hf_opt/record_zo2_memory.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD2=\"python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30\"\n\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Recording Peak GPU and CPU Memory usage...\"\n        max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n        max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n        if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n            echo \"Could not find memory usage data in the output.\"\n        else\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n            echo -e \"ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}\"\n        fi\n\n        rm $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/record_zo2_speed.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD2=\"python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30\"\n\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Recording throughput...\"\n        \n        # Count the total number of lines and determine the number of iteration lines\n        total_lines2=$(wc -l < $OUT2)\n        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n        # Calculate the starting line for the last 50% of iterations\n        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n        # Calculate average tokens per second for the last 50% of the iterations\n        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n        echo -e \"Model: $model_name, Task: $task_id\"\n        echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n        \n        rm $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_acc.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.opt.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    OPTConfigs,\n    prepare_data_for_causalLM, \n    prepare_data_for_sequence_classification,\n    prepare_data_for_question_answering,\n    model_size, \n    get_args\n)\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = 
\"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\n\ndef train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForSequenceClassification(model_config).to(device)\n    
model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = 
model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\n\ndef train_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForQuestionAnswering(model_config).to(\"cuda\")\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):\n    input_ids, start_positions, end_positions = 
prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForQuestionAnswering(model_config).to(\"cuda\")\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForQuestionAnswering(model_config).to(\"cuda\")\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = 
zo2.OPTForQuestionAnswering(model_config).to(\"cuda\")\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\n\ndef test_mezo_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef 
test_mezo2_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_sequence_classification_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_sequence_classification_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    
eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_question_answering_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_question_answering_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        
working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.task == \"causalLM\":\n            if args.eval:\n                test_mezo_sgd_causalLM_eval()\n            else:\n                test_mezo_sgd_causalLM_training()\n        elif args.task == \"sequence_classification\":\n            if args.eval:\n                test_mezo_sgd_sequence_classification_eval()\n            else:\n                test_mezo_sgd_sequence_classification_training()\n        elif args.task == \"question_answering\":\n            if args.eval:\n                test_mezo_sgd_question_answering_eval()\n            else:\n                test_mezo_sgd_question_answering_training()\n        else:\n            raise NotImplementedError(f\"Task {args.task} is unsupported.\")\n    elif args.zo_method == \"zo2\":\n        if args.task == \"causalLM\":\n            if args.eval:\n                test_mezo2_sgd_causalLM_eval()\n            else:\n                test_mezo2_sgd_causalLM_training()\n        elif args.task == \"sequence_classification\":\n            if args.eval:\n                test_mezo2_sgd_sequence_classification_eval()\n            else:\n                
test_mezo2_sgd_sequence_classification_training()\n        elif args.task == \"question_answering\":\n            if args.eval:\n                test_mezo2_sgd_question_answering_eval()\n            else:\n                test_mezo2_sgd_question_answering_training()\n        else:\n            raise NotImplementedError(f\"Task {args.task} is unsupported.\")\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_acc_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\" \"sequence_classification\" \"question_answering\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Keep per-run outputs in a private temp dir so concurrent runs of the test\n# scripts cannot clobber each other's files, and so the files are removed even\n# when `set -e` aborts the script after a failed run.\nTMP_DIR=$(mktemp -d)\ntrap 'rm -rf \"$TMP_DIR\"' EXIT\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        if [ \"$task_id\" == \"causalLM\" ]; then\n            lr=1e-4\n        else\n            lr=1e-7\n        fi\n\n        CMD1=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr --eval --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30\"\n\n        OUT1=\"$TMP_DIR/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"$TMP_DIR/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        # Lines look like \"Iteration <i>, loss: <loss>\"; $2 is \"<i>,\" so it is\n        # printed with %d (which drops the trailing comma), as in test_acc_train.sh.\n        paste <(grep 'Iteration' \"$OUT1\") <(grep 'Iteration' \"$OUT2\") | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n            split($4, loss1, \",\");\n            split($8, loss2, \",\");\n            diff_loss = loss1[1] - loss2[1];\n            if (loss1[1] == loss2[1])\n                printf \"Iteration %d: %s✓ loss match.%s\\n\", $2, green, nc;\n            else\n                printf \"Iteration %d: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) Loss diff: %.6f%s\\n\", $2, red, loss1[1], loss2[1], diff_loss, nc;\n        }'\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_acc_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\" \"sequence_classification\" \"question_answering\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Keep per-run outputs in a private temp dir so concurrent runs of the test\n# scripts cannot clobber each other's files, and so the files are removed even\n# when `set -e` aborts the script after a failed run.\nTMP_DIR=$(mktemp -d)\ntrap 'rm -rf \"$TMP_DIR\"' EXIT\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        if [ \"$task_id\" == \"causalLM\" ]; then\n            lr=1e-4\n        else\n            lr=1e-7\n        fi\n\n        CMD1=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo --lr $lr\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr\"\n\n        OUT1=\"$TMP_DIR/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"$TMP_DIR/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        paste <(grep 'Iteration' \"$OUT1\") <(grep 'Iteration' \"$OUT2\") | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n            split($4, loss1, \",\");\n            split($7, proj1, \",\");\n            split($11, loss2, \",\");\n            split($14, proj2, \",\");\n            diff_loss = loss1[1] - loss2[1];\n            diff_proj = proj1[1] - proj2[1];\n            if (loss1[1] == loss2[1] && proj1[1] == proj2[1])\n                printf \"Iteration %d: %s✓ loss and projected grad match.%s\\n\", $2, green, nc;\n            else\n                printf \"Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\\n \\tLoss diff: %.6f, Proj grad diff: %.6f%s\\n\", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;\n        }'\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_memory.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.opt.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    OPTConfigs,\n    prepare_data_for_causalLM, \n    prepare_data_for_sequence_classification,\n    prepare_data_for_question_answering,\n    model_size, \n    get_args,\n    check_peak_gpu_memory_usage,\n    reset_peak_cpu_memory_usage, \n    check_and_update_peak_cpu_memory_usage\n)\n\ndef _cuda_device_index(device):\n    \"\"\"Return the integer CUDA ordinal for a device spec such as 'cuda:1' or 'cuda'.\n\n    Taking the last character (int(device[-1])) fails for a bare 'cuda' and\n    misreads multi-digit ordinals such as 'cuda:10'. A spec without an explicit\n    index falls back to the current CUDA device.\n    \"\"\"\n    index = torch.device(device).index\n    return torch.cuda.current_device() if index is None else index\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\n\ndef train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_sequence_classification(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForSequenceClassification(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\n\ndef train_mezo_sgd_question_answering(model_config, zo_config, device='cuda:0'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForQuestionAnswering(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForQuestionAnswering(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda:0'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.OPTForQuestionAnswering(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda:0'):\n    input_ids, start_positions, end_positions = prepare_data_for_question_answering(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.OPTForQuestionAnswering(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats(_cuda_device_index(device))\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, start_positions=start_positions, end_positions=end_positions)\n        check_peak_gpu_memory_usage(i, _cuda_device_index(device), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\n\ndef test_mezo_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_sequence_classification_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_sequence_classification_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_question_answering_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_training():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_question_answering_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_eval():\n    seed_everything(args.seed)\n    model_configs = OPTConfigs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.task == \"causalLM\":\n            if args.eval:\n                test_mezo_sgd_causalLM_eval()\n            else:\n                test_mezo_sgd_causalLM_training()\n        elif args.task == \"sequence_classification\":\n            if args.eval:\n                test_mezo_sgd_sequence_classification_eval()\n            else:\n                test_mezo_sgd_sequence_classification_training()\n        elif args.task == \"question_answering\":\n            if args.eval:\n                test_mezo_sgd_question_answering_eval()\n            else:\n                test_mezo_sgd_question_answering_training()\n        else:\n            raise NotImplementedError(f\"Task {args.task} is unsupported.\")\n    elif args.zo_method == \"zo2\":\n        if args.task == \"causalLM\":\n            if args.eval:\n                test_mezo2_sgd_causalLM_eval()\n            else:\n                test_mezo2_sgd_causalLM_training()\n        elif args.task == \"sequence_classification\":\n            if args.eval:\n                test_mezo2_sgd_sequence_classification_eval()\n            else:\n                test_mezo2_sgd_sequence_classification_training()\n        elif args.task == \"question_answering\":\n            if args.eval:\n                test_mezo2_sgd_question_answering_eval()\n            else:\n                test_mezo2_sgd_question_answering_training()\n        else:\n            raise NotImplementedError(f\"Task {args.task} is unsupported.\")\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_memory_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Keep per-run outputs in a private temp dir so concurrent runs of the test\n# scripts cannot clobber each other's files, and so the files are removed even\n# when `set -e` aborts the script after a failed run.\nTMP_DIR=$(mktemp -d)\ntrap 'rm -rf \"$TMP_DIR\"' EXIT\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD1=\"python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval\"\n\n        OUT1=\"$TMP_DIR/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"$TMP_DIR/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Analyzing Peak GPU Memory usage...\"\n        # `|| true`: under `set -e` + `pipefail`, a grep with no match would make\n        # the assignment fail and abort the script, so the \"Could not find\"\n        # branch below could never run.\n        max_mem1=$(grep 'Peak GPU Memory' \"$OUT1\" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)\n        max_mem2=$(grep 'Peak GPU Memory' \"$OUT2\" | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1 || true)\n\n        if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n            echo \"Could not find memory usage data in the output.\"\n        else\n            # Multiply before dividing: with scale=2, \"a / b * 100\" truncates the\n            # quotient to 2 decimals first, leaving the percentage with no\n            # fractional precision.\n            ratio=$(echo \"scale=2; $max_mem2 * 100 / $max_mem1\" | bc)\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n            echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}\"\n            echo -e \"Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}\"\n        fi\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_memory_train.sh",
    "content": "#!/bin/bash\n# Compare the peak GPU memory of ZO (MeZO-SGD) and ZO2 in training mode for\n# every OPT model size and report the ZO2/ZO ratio.\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Print the largest \"Peak GPU Memory\" value (MB, 7th field) in log file $1,\n# or nothing if the log contains no such line. \"|| true\" stops set -e/pipefail\n# from aborting the whole sweep when grep finds no match.\nmax_peak_gpu_mem() {\n    { grep 'Peak GPU Memory' \"$1\" || true; } | awk '{print $7}' | sort -nr | head -1\n}\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n\n        CMD1=\"python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_memory.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        # CMD1/CMD2 are deliberately unquoted so they split into words.\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Analyzing Peak GPU Memory usage...\"\n        max_mem1=$(max_peak_gpu_mem \"$OUT1\")\n        max_mem2=$(max_peak_gpu_mem \"$OUT2\")\n\n        if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n            echo \"Could not find memory usage data in the output.\"\n        else\n            # Multiply before dividing: bc truncates every quotient to scale\n            # digits, so \"a / b * 100\" would only yield whole percentages.\n            ratio=$(echo \"scale=2; $max_mem2 * 100 / $max_mem1\" | bc)\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n            echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}\"\n            echo -e \"Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}\"\n        fi\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_scheduler_acc_eval.sh",
    "content": "#!/bin/bash\n# Check that ZO2 eval losses are identical with the scheduler overlap disabled\n# (--overlap no) and enabled (the default) for every OPT model size and task.\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\" \"sequence_classification\" \"question_answering\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n\n        if [ \"$task_id\" == \"causalLM\" ]; then\n            lr=1e-4\n        else\n            lr=1e-7\n        fi\n\n        # Both runs use ZO2; they differ only in the overlap setting.\n        CMD1=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --overlap no --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_acc.py --model_name $model_name --task $task_id --zo_method zo2 --lr $lr --eval --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        # Pair the i-th \"Iteration\" line of both logs. Assumes each eval \"Iteration\"\n        # line of test_acc.py has 4 fields with the loss last, so the overlap run's\n        # loss is field 8 after paste -- confirm against test_acc.py.\n        paste <(grep 'Iteration' \"$OUT1\") <(grep 'Iteration' \"$OUT2\") | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n            split($4, loss1, \",\");\n            split($8, loss2, \",\");\n            diff_loss = loss1[1] - loss2[1];\n            if (loss1[1] == loss2[1])\n                printf \"Iteration %s: %s✓ loss match.%s\\n\", $2, green, nc;\n            else\n                printf \"Iteration %s: %s✗ Mismatch! Non-Overlap (loss): (%s), Overlap (loss): (%s) Loss diff: %.6f%s\\n\", $2, red, loss1[1], loss2[1], diff_loss, nc;\n        }'\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_scheduler_acc_train.sh",
    "content": "#!/bin/bash\n# Scheduler regression check (training): for every OPT model size and task, run\n# ZO2 once with overlap disabled and once with the default overlap setting, then\n# compare the per-iteration loss and projected gradient of the two runs.\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\" \"sequence_classification\" \"question_answering\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nACC_SCRIPT=\"python test/mezo_sgd/hf_opt/test_acc.py\"\n\nfor model_name in \"${model_names[@]}\"; do\n    for task_id in \"${task_ids[@]}\"; do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n\n        # causalLM uses a larger learning rate than the other tasks.\n        case \"$task_id\" in\n            causalLM) lr=1e-4 ;;\n            *) lr=1e-7 ;;\n        esac\n\n        shared_args=\"--model_name $model_name --task $task_id --zo_method zo2 --lr $lr\"\n        CMD1=\"$ACC_SCRIPT $shared_args --overlap no --max_steps 30\"\n        CMD2=\"$ACC_SCRIPT $shared_args --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        # The command strings are expanded unquoted on purpose (word splitting).\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        # After paste, fields 4/7 hold the non-overlap loss/projected grad and\n        # fields 11/14 the overlap ones; split() drops any trailing comma.\n        paste <(grep 'Iteration' \"$OUT1\") <(grep 'Iteration' \"$OUT2\") |\n            awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n                split($4, ref_loss, \",\"); split($7, ref_proj, \",\");\n                split($11, ovl_loss, \",\"); split($14, ovl_proj, \",\");\n                if (ref_loss[1] == ovl_loss[1] && ref_proj[1] == ovl_proj[1]) {\n                    printf \"Iteration %d: %s✓ loss and projected grad match.%s\\n\", $2, green, nc;\n                } else {\n                    printf \"Iteration %d: %s✗ Mismatch! Non-Overlap (loss, grad): (%s, %s), Overlap (loss, grad): (%s, %s)\\n \\tLoss diff: %.6f, Proj grad diff: %.6f%s\\n\", $2, red, ref_loss[1], ref_proj[1], ovl_loss[1], ovl_proj[1], ref_loss[1] - ovl_loss[1], ref_proj[1] - ovl_proj[1], nc;\n                }\n            }'\n\n        rm \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_speed.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\"\"\"Throughput benchmark of MeZO-SGD (ZO) vs. ZO2 on randomly initialised OPT models.\n\nThe selected ``--zo_method``/``--task``/``--eval`` combination builds an OPT\nmodel of size ``--model_name`` in ``--model_dtype`` and feeds it the same random\nbatch for ``--max_steps`` steps; ``utils.check_throughput`` prints the latency\nand tokens/s of every step.\n\"\"\"\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.opt.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    OPTConfigs,\n    prepare_data_for_causalLM,\n    prepare_data_for_sequence_classification,\n    prepare_data_for_question_answering,\n    model_size,\n    get_args,\n    check_throughput\n)\n\n# Keyword names under which each prepare_data_for_* result is passed to the model.\n_CAUSAL_LM_KEYS = (\"input_ids\", \"labels\")\n_SEQ_CLS_KEYS = (\"input_ids\", \"labels\")\n_QA_KEYS = (\"input_ids\", \"start_positions\", \"end_positions\")\n\n\ndef _make_batch(prepare_fn, keys, model_config, device):\n    \"\"\"Create one random batch with ``prepare_fn`` and name its tensors after ``keys``.\"\"\"\n    tensors = prepare_fn(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    return dict(zip(keys, tensors))\n\n\ndef _run_benchmark(model_cls, batch, model_config, zo_config, device, train, move_to_device):\n    \"\"\"Build ``model_cls`` in ``args.model_dtype`` and time ``args.max_steps`` steps.\n\n    Args:\n        model_cls: OPT model class from the ``zo`` or ``zo2`` module.\n        batch: keyword arguments forwarded to the model on every step.\n        model_config: OPT config used to build the model and to count tokens per step.\n        zo_config: MeZO-SGD config passed to ``model.zo_init``.\n        device: target of ``model.to`` when ``move_to_device`` is true.\n        train: call ``model.zo_train()`` before each step if true, else ``model.zo_eval()``.\n        move_to_device: whether to move the freshly built model to ``device``.\n    \"\"\"\n    torch.set_default_dtype(args.model_dtype)\n    try:\n        model = model_cls(model_config)\n        if move_to_device:\n            model = model.to(device)\n        model.zo_init(zo_config)\n        total_parameters = model_size(model)[\"total\"]\n        print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    finally:\n        # Restore the default dtype even if model construction fails.\n        torch.set_default_dtype(original_dtype)\n    tokens_per_step = args.batch_size * model_config.max_position_embeddings\n    for i in tqdm(range(args.max_steps)):\n        if train:\n            model.zo_train()\n        else:\n            model.zo_eval()\n        check_throughput(i, tokens_per_step, model, use_tqdm=True, **batch)\n\n\n# Benchmark drivers. All ZO models and the ZO2 classification/QA models are moved\n# to ``device``; the ZO2 causalLM model is not moved.\n# NOTE(review): ZO2 configs carry an offloading_device; confirm that moving the\n# whole ZO2 classification/QA model to ``device`` is intended, since it may defeat\n# offloading and skew the ZO vs. ZO2 comparison.\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_causalLM, _CAUSAL_LM_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForCausalLM, batch, model_config, zo_config, device, train=True, move_to_device=True)\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_causalLM, _CAUSAL_LM_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForCausalLM, batch, model_config, zo_config, device, train=True, move_to_device=False)\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_causalLM, _CAUSAL_LM_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForCausalLM, batch, model_config, zo_config, device, train=False, move_to_device=True)\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_causalLM, _CAUSAL_LM_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForCausalLM, batch, model_config, zo_config, device, train=False, move_to_device=False)\n\n\ndef train_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_sequence_classification, _SEQ_CLS_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForSequenceClassification, batch, model_config, zo_config, device, train=True, move_to_device=True)\n\ndef train_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_sequence_classification, _SEQ_CLS_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForSequenceClassification, batch, model_config, zo_config, device, train=True, move_to_device=True)\n\ndef eval_mezo_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_sequence_classification, _SEQ_CLS_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForSequenceClassification, batch, model_config, zo_config, device, train=False, move_to_device=True)\n\ndef eval_mezo2_sgd_sequence_classification(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_sequence_classification, _SEQ_CLS_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForSequenceClassification, batch, model_config, zo_config, device, train=False, move_to_device=True)\n\n\ndef train_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_question_answering, _QA_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForQuestionAnswering, batch, model_config, zo_config, device, train=True, move_to_device=True)\n\ndef train_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_question_answering, _QA_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForQuestionAnswering, batch, model_config, zo_config, device, train=True, move_to_device=True)\n\ndef eval_mezo_sgd_question_answering(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_question_answering, _QA_KEYS, model_config, device)\n    _run_benchmark(zo.OPTForQuestionAnswering, batch, model_config, zo_config, device, train=False, move_to_device=True)\n\ndef eval_mezo2_sgd_question_answering(model_config, zo_config, device='cuda'):\n    batch = _make_batch(prepare_data_for_question_answering, _QA_KEYS, model_config, device)\n    _run_benchmark(zo2.OPTForQuestionAnswering, batch, model_config, zo_config, device, train=False, move_to_device=True)\n\n\ndef _build_configs(use_zo2):\n    \"\"\"Call ``seed_everything(args.seed)`` and build the (OPT config, MeZO-SGD config) pair.\n\n    ZO2 configs additionally get ``offloading_device``; ``zo_cfg.zo2`` records\n    which variant is being benchmarked.\n    \"\"\"\n    seed_everything(args.seed)\n    model_config = getattr(OPTConfigs(), args.model_name)\n    model_config.tie_word_embeddings = False\n    model_config.max_position_embeddings = args.sequence_length\n    cfg_kwargs = dict(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n                      working_device=args.working_device)\n    if use_zo2:\n        cfg_kwargs[\"offloading_device\"] = args.offloading_device\n    zo_cfg = MeZOSGDConfig(**cfg_kwargs)\n    zo_cfg.zo2 = use_zo2\n    return model_config, zo_cfg\n\n\ndef test_mezo_sgd_causalLM_training():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_sequence_classification_training():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    train_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_training():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    train_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_sequence_classification_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    eval_mezo_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_sequence_classification_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    eval_mezo2_sgd_sequence_classification(model_config, zo_cfg, device=args.working_device)\n\n\ndef test_mezo_sgd_question_answering_training():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    train_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_training():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    train_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_question_answering_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=False)\n    eval_mezo_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_question_answering_eval():\n    model_config, zo_cfg = _build_configs(use_zo2=True)\n    eval_mezo2_sgd_question_answering(model_config, zo_cfg, device=args.working_device)\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    # zo_method -> task -> (eval entry point, training entry point)\n    entry_points = {\n        \"zo\": {\n            \"causalLM\": (test_mezo_sgd_causalLM_eval, test_mezo_sgd_causalLM_training),\n            \"sequence_classification\": (test_mezo_sgd_sequence_classification_eval, test_mezo_sgd_sequence_classification_training),\n            \"question_answering\": (test_mezo_sgd_question_answering_eval, test_mezo_sgd_question_answering_training),\n        },\n        \"zo2\": {\n            \"causalLM\": (test_mezo2_sgd_causalLM_eval, test_mezo2_sgd_causalLM_training),\n            \"sequence_classification\": (test_mezo2_sgd_sequence_classification_eval, test_mezo2_sgd_sequence_classification_training),\n            \"question_answering\": (test_mezo2_sgd_question_answering_eval, test_mezo2_sgd_question_answering_training),\n        },\n    }\n    if args.zo_method not in entry_points:\n        raise NotImplementedError\n    tasks = entry_points[args.zo_method]\n    if args.task not in tasks:\n        raise NotImplementedError(f\"Task {args.task} is unsupported.\")\n    eval_fn, train_fn = tasks[args.task]\n    (eval_fn if args.eval else train_fn)()\n"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_speed_eval.sh",
    "content": "#!/bin/bash\n# Compare the eval throughput (tokens/s) of ZO (MeZO-SGD) and ZO2 for every OPT\n# model size, averaging over the last 50% of iterations to skip warm-up steps.\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Print the mean tok/s over the last half (rounded up) of the \"Time cost after\n# iteration\" lines in log file $1, or nothing if there are none. The rate is the\n# field right before \"tok/s\" (not a fixed column), and \"|| true\" keeps\n# set -e/pipefail from aborting the sweep when grep finds no match.\navg_tok_s_last_half() {\n    local rates n\n    rates=$({ grep 'Time cost after iteration' \"$1\" || true; } |\n        awk '{for (i = 2; i <= NF; i++) if ($i == \"tok/s\") print $(i - 1)}')\n    if [ -z \"$rates\" ]; then\n        return 0\n    fi\n    n=$(printf '%s\\n' \"$rates\" | wc -l)\n    printf '%s\\n' \"$rates\" | tail -n \"$(( (n + 1) / 2 ))\" |\n        awk '{total += $1; count++} END {printf \"%.2f\\n\", total / count}'\n}\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n\n        CMD1=\"python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30 --eval\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30 --eval\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        # CMD1/CMD2 are deliberately unquoted so they split into words.\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Analyzing throughput...\"\n        avg_tok_s1=$(avg_tok_s_last_half \"$OUT1\")\n        avg_tok_s2=$(avg_tok_s_last_half \"$OUT2\")\n\n        if [ -z \"$avg_tok_s1\" ] || [ -z \"$avg_tok_s2\" ]; then\n            echo \"Could not find throughput data in the output.\"\n        else\n            # Multiply before dividing: bc truncates every quotient to scale\n            # digits, so \"a / b * 100\" would only yield whole percentages.\n            ratio=$(echo \"scale=2; $avg_tok_s2 * 100 / $avg_tok_s1\" | bc)\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}\"\n            echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n            echo -e \"Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}\"\n        fi\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/test_speed_train.sh",
    "content": "#!/bin/bash\n# Compare the training throughput (tokens/s) of ZO (MeZO-SGD) and ZO2 for every\n# OPT model size, averaging over the last 50% of iterations to skip warm-up steps.\n\nset -e\nset -o pipefail\n\nmodel_names=(\"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\n# Print the mean tok/s over the last half (rounded up) of the \"Time cost after\n# iteration\" lines in log file $1, or nothing if there are none. The rate is the\n# field right before \"tok/s\" (not a fixed column), and \"|| true\" keeps\n# set -e/pipefail from aborting the sweep when grep finds no match.\navg_tok_s_last_half() {\n    local rates n\n    rates=$({ grep 'Time cost after iteration' \"$1\" || true; } |\n        awk '{for (i = 2; i <= NF; i++) if ($i == \"tok/s\") print $(i - 1)}')\n    if [ -z \"$rates\" ]; then\n        return 0\n    fi\n    n=$(printf '%s\\n' \"$rates\" | wc -l)\n    printf '%s\\n' \"$rates\" | tail -n \"$(( (n + 1) / 2 ))\" |\n        awk '{total += $1; count++} END {printf \"%.2f\\n\", total / count}'\n}\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n\n        CMD1=\"python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_opt/test_speed.py --model_name $model_name --task $task_id --zo_method zo2 --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        # CMD1/CMD2 are deliberately unquoted so they split into words.\n        $CMD1 2>&1 | tee \"$OUT1\"\n        $CMD2 2>&1 | tee \"$OUT2\"\n\n        echo \"Analyzing throughput...\"\n        avg_tok_s1=$(avg_tok_s_last_half \"$OUT1\")\n        avg_tok_s2=$(avg_tok_s_last_half \"$OUT2\")\n\n        if [ -z \"$avg_tok_s1\" ] || [ -z \"$avg_tok_s2\" ]; then\n            echo \"Could not find throughput data in the output.\"\n        else\n            # Multiply before dividing: bc truncates every quotient to scale\n            # digits, so \"a / b * 100\" would only yield whole percentages.\n            ratio=$(echo \"scale=2; $avg_tok_s2 * 100 / $avg_tok_s1\" | bc)\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}\"\n            echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n            echo -e \"Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}\"\n        fi\n\n        rm -f \"$OUT1\" \"$OUT2\"\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_opt/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\"\"\"Shared helpers for the hf_opt benchmarks: CLI parsing, OPT size presets,\nrandom input batches, and memory / throughput reporting.\n\nThe report lines printed here (\"Peak GPU Memory after iteration N: X MB\" and\n\"Time cost after iteration N: X ms, Y tok/s\") are parsed by field position in\nthe test_*.sh scripts, so keep their wording stable.\n\"\"\"\n\nimport torch\nimport time\nimport argparse\nfrom tqdm import tqdm\nimport psutil\nimport os\nfrom transformers import OPTConfig\nimport pynvml\n\n\n# Maps each --model_dtype name to the torch dtype used to build the model.\ndtype_lookup = {\n    \"fp64\": torch.float64,\n    \"fp32\": torch.float32,\n    \"fp16\": torch.float16,\n    \"bf16\": torch.bfloat16\n}\n\n\ndef get_args():\n    \"\"\"Parse the benchmark command line.\n\n    Returns:\n        argparse.Namespace whose ``model_dtype`` has already been converted from\n        its name (a key of ``dtype_lookup``) to the matching ``torch.dtype``.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--zo_method\", type=str, default=\"zo2\")\n    parser.add_argument(\"--eval\", action=\"store_true\")\n    parser.add_argument(\"--task\", type=str, default=\"causalLM\")\n    parser.add_argument(\"--model_name\", type=str, default=\"opt_125m\")\n    # choices= turns an unknown dtype name into a usage error instead of a KeyError.\n    parser.add_argument(\"--model_dtype\", type=str, default=\"fp16\", choices=list(dtype_lookup))\n    parser.add_argument(\"--verbose\", action=\"store_true\")\n    parser.add_argument(\"--max_steps\", type=int, default=3)\n    parser.add_argument(\"--lr\", type=float, default=1e-3)\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-1)\n    parser.add_argument(\"--zo_eps\", type=float, default=1e-3)\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--batch_size\", type=int, default=1)\n    parser.add_argument(\"--sequence_length\", type=int, default=2048)\n    parser.add_argument(\"--overlap\", type=str, default=\"all\")\n    parser.add_argument(\"--offloading_device\", type=str, default=\"cpu\")\n    parser.add_argument(\"--working_device\", type=str, default=\"cuda:0\")\n    args = parser.parse_args()\n    args.model_dtype = dtype_lookup[args.model_dtype]\n    return args\n\n\nclass OPTConfigs:\n    \"\"\"OPT architecture presets from 125M to 175B parameters.\n\n    The presets are shared class attributes: a caller that mutates one (e.g.\n    ``max_position_embeddings``) changes it for every later lookup in the process.\n    \"\"\"\n    opt_125m: OPTConfig = OPTConfig(num_hidden_layers=12, num_attention_heads=12, hidden_size=768, ffn_dim=3072, max_position_embeddings=2048)\n    opt_350m: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=16, hidden_size=1024, ffn_dim=4096, max_position_embeddings=2048)\n    opt_1_3b: OPTConfig = OPTConfig(num_hidden_layers=24, num_attention_heads=32, hidden_size=2048, ffn_dim=8192, max_position_embeddings=2048)\n    opt_2_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=2560, ffn_dim=10240, max_position_embeddings=2048)\n    opt_6_7b: OPTConfig = OPTConfig(num_hidden_layers=32, num_attention_heads=32, hidden_size=4096, ffn_dim=16384, max_position_embeddings=2048)\n    opt_13b: OPTConfig = OPTConfig(num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, ffn_dim=20480, max_position_embeddings=2048)\n    opt_30b: OPTConfig = OPTConfig(num_hidden_layers=48, num_attention_heads=56, hidden_size=7168, ffn_dim=28672, max_position_embeddings=2048)\n    opt_66b: OPTConfig = OPTConfig(num_hidden_layers=64, num_attention_heads=72, hidden_size=9216, ffn_dim=36864, max_position_embeddings=2048)\n    opt_175b: OPTConfig = OPTConfig(num_hidden_layers=96, num_attention_heads=96, hidden_size=12288, ffn_dim=49152, max_position_embeddings=2048)\n\n\ndef model_size(model: torch.nn.Module):\n    \"\"\"Return ``{\"total\": ..., \"trainable\": ...}`` parameter counts of ``model``.\"\"\"\n    total_size = sum(p.numel() for p in model.parameters())\n    trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    return {\"total\": total_size, \"trainable\": trainable_size}\n\n\ndef prepare_data_for_causalLM(V, B, T, device='cuda'):\n    \"\"\"Random token ids of shape (B, T) in [0, V); the same tensor is returned as labels.\"\"\"\n    data_batch = torch.randint(0, V, (B, T)).to(device)\n    return data_batch, data_batch\n\ndef prepare_data_for_sequence_classification(V, B, T, device='cuda'):\n    \"\"\"Random (B, T) token ids and B labels (randint(0, 1) makes every label 0).\"\"\"\n    input_ids = torch.randint(0, V, (B, T)).to(device)\n    labels = torch.randint(0, 1, (B, )).to(device)\n    return input_ids, labels\n\ndef prepare_data_for_question_answering(V, B, T, device='cuda'):\n    \"\"\"Random (B, T) token ids; every answer span has start 0 and end 1.\"\"\"\n    input_ids = torch.randint(0, V, (B, T)).to(device)\n    start_positions = torch.randint(0, 1, (B, )).to(device)\n    end_positions = torch.randint(1, 2, (B, )).to(device)\n    return input_ids, start_positions, end_positions\n\n\ndef _report(message, use_tqdm):\n    \"\"\"Print ``message`` through tqdm.write (keeps progress bars intact) or print.\"\"\"\n    if use_tqdm:\n        tqdm.write(message)\n    else:\n        print(message)\n\n\n# GPU Memory Monitoring: NVML is initialised once, when this module is imported.\npynvml.nvmlInit()\ndef check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):\n    \"\"\"Print the memory in use on NVML GPU index ``device``, in MB.\n\n    NOTE: NVML's ``used`` is an instantaneous, device-wide value rather than a\n    true peak; the test_memory_*.sh scripts take the maximum over iterations.\n    \"\"\"\n    handle = pynvml.nvmlDeviceGetHandleByIndex(device)\n    info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n    used_mb = info.used / 1024**2\n    _report(f\"Peak GPU Memory after iteration {iter+1}: {used_mb:.2f} MB\", use_tqdm)\n\n# CPU Memory Monitoring: running maximum of this process's resident set size, in MB.\npeak_memory_cpu = 0\ndef check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):\n    \"\"\"Fold the current RSS into ``peak_memory_cpu`` and print the peak so far.\"\"\"\n    global peak_memory_cpu\n    current_memory = psutil.Process(os.getpid()).memory_info().rss / (1024 ** 2)\n    peak_memory_cpu = max(peak_memory_cpu, current_memory)\n    _report(f\"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB\", use_tqdm)\n\ndef reset_peak_cpu_memory_usage():\n    \"\"\"Reset ``peak_memory_cpu`` and, if CUDA is available, torch's CUDA peak-memory stats.\"\"\"\n    global peak_memory_cpu\n    peak_memory_cpu = 0\n    if torch.cuda.is_available():\n        torch.cuda.reset_peak_memory_stats()\n\n\ndef check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs):\n    \"\"\"Time one call ``fn(*args, **kwargs)`` and print its latency and tokens/s.\n\n    Args:\n        iter: zero-based step index (printed one-based).\n        total_token_batch_size_per_iter: number of tokens processed by one call.\n        fn: the callable to time.\n        use_tqdm: print through tqdm.write instead of print.\n\n    Returns:\n        Whatever ``fn`` returns.\n\n    NOTE(review): there is no CUDA synchronisation around ``fn``, so GPU work\n    still queued when ``fn`` returns is not timed -- confirm ``fn`` blocks.\n    \"\"\"\n    start = time.perf_counter()  # monotonic, high-resolution interval timer\n    out = fn(*args, **kwargs)\n    time_cost = time.perf_counter() - start\n    throughput = total_token_batch_size_per_iter / time_cost if time_cost > 0 else float(\"inf\")\n    _report(\"Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s\".format(iter+1, time_cost*1e3, throughput), use_tqdm)\n    return out\n"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/record_zo2_memory.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo2 --max_steps 30\"\n\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Recording Peak GPU and CPU Memory usage...\"\n        max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n        max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n        if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n            echo \"Could not find memory usage data in the output.\"\n        else\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n            echo -e \"ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}\"\n        fi\n\n        rm $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/record_zo2_speed.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo2 --max_steps 30\"\n\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Recording throughput...\"\n        \n        # Count the total number of lines and determine the number of iteration lines\n        total_lines2=$(wc -l < $OUT2)\n        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n        # Calculate the starting line for the last 50% of iterations\n        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n        # Calculate average tokens per second for the last 50% of the iterations\n        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n        echo -e \"Model: $model_name, Task: $task_id\"\n        echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n        \n        rm $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_acc.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    Qwen3Configs,\n    prepare_data_for_causalLM, \n    model_size, \n    get_args\n)\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model.zo_train()\n        loss = model(input_ids=input_ids, labels=labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, 
model.opt.projected_grad))\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model.zo_eval()\n        loss = model(input_ids=input_ids, labels=labels)[\"loss\"]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\n\ndef test_mezo_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    # model_config._attn_implementation = \"eager\"\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        
working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    # model_config._attn_implementation = \"eager\"\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    zo_cfg.overlap = args.overlap==\"all\"\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    # 
torch.set_printoptions(precision=10)\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_causalLM_eval()\n        else:\n            test_mezo_sgd_causalLM_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_causalLM_eval()\n        else:\n            test_mezo2_sgd_causalLM_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_acc_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        if [ \"$task_id\" == \"causalLM\" ]; then\n            lr=1e-4\n        else\n            lr=1e-7\n        fi\n\n        CMD1=\"python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo --lr $lr --eval --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo2 --lr $lr --eval --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee $OUT1\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n            split($4, loss1, \",\");\n            split($8, loss2, \",\");\n            diff_loss = loss1[1] - loss2[1];\n            if (loss1[1] == loss2[1])\n                printf \"Iteration %s: %s✓ loss match.%s\\n\", $2, green, nc;\n            else\n                printf \"Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) Loss diff: %.6f%s\\n\", $2, red, loss1[1], loss2[1], diff_loss, nc;\n        }'\n\n        rm $OUT1 $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_acc_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name\"\n        \n        if [ \"$task_id\" == \"causalLM\" ]; then\n            lr=1e-4\n        else\n            lr=1e-7\n        fi\n\n        CMD1=\"python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo --lr $lr\"\n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_acc.py --model_name $model_name --zo_method zo2 --lr $lr\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee $OUT1\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Comparing outputs...\"\n        echo -e \"Model: $model_name, Task: $task_id\"\n        paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n            split($4, loss1, \",\");\n            split($7, proj1, \",\");\n            split($11, loss2, \",\");\n            split($14, proj2, \",\");\n            diff_loss = loss1[1] - loss2[1];\n            diff_proj = proj1[1] - proj2[1];\n            if (loss1[1] == loss2[1] && proj1[1] == proj2[1])\n                printf \"Iteration %d: %s✓ loss and projected grad match.%s\\n\", $2, green, nc;\n            else\n                printf \"Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\\n \\tLoss diff: %.6f, Proj grad diff: %.6f%s\\n\", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;\n        }'\n\n        rm $OUT1 $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_memory.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    Qwen3Configs,\n    prepare_data_for_causalLM, \n    model_size, \n    get_args,\n    reset_peak_cpu_memory_usage,\n    check_peak_gpu_memory_usage,\n    check_and_update_peak_cpu_memory_usage,\n)\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in 
tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids=input_ids, labels=labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\n\n\ndef test_mezo_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n  
  model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, 
device=args.working_device)\n\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_causalLM_eval()\n        else:\n            test_mezo_sgd_causalLM_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_causalLM_eval()\n        else:\n            test_mezo2_sgd_causalLM_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_memory_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD1=\"python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_memory.py --model_name $model_name --zo_method zo2 --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee $OUT1\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Analyzing Peak GPU Memory usage...\"\n        max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n        max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n        if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n            echo \"Could not find memory usage data in the output.\"\n        else\n            ratio=$(echo \"scale=2; $max_mem2 / $max_mem1 * 100\" | bc)\n            echo -e \"Model: $model_name, Task: $task_id\"\n            echo -e \"ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n            echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}\"\n            echo -e \"Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}\"\n        fi\n\n        rm $OUT1 $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_speed.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.huggingface.qwen3.mezo_sgd import zo, zo2\nfrom zo2.utils.utils import seed_everything\nfrom utils import (\n    Qwen3Configs,\n    prepare_data_for_causalLM, \n    model_size, \n    get_args,\n    check_throughput\n)\n\ndef train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)\n\ndef train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)\n\ndef eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, 
labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo.Qwen3ForCausalLM(model_config).to(device)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)\n\ndef eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'):\n    input_ids, labels = prepare_data_for_causalLM(\n        model_config.vocab_size, args.batch_size, model_config.max_position_embeddings, device)\n    torch.set_default_dtype(args.model_dtype)\n    model = zo2.Qwen3ForCausalLM(model_config)\n    model.zo_init(zo_config)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    torch.set_default_dtype(original_dtype)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        check_throughput(i, args.batch_size*model_config.max_position_embeddings, model, input_ids=input_ids, labels=labels, use_tqdm=True)\n\n\ndef test_mezo_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    train_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_training():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config 
= getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    train_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    eval_mezo_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\ndef test_mezo2_sgd_causalLM_eval():\n    seed_everything(args.seed)\n    model_configs = Qwen3Configs()\n    model_config = getattr(model_configs, args.model_name)\n    model_config.tie_word_embeddings=False\n    model_config.max_position_embeddings = args.sequence_length\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    eval_mezo2_sgd_causalLM(model_config, zo_cfg, device=args.working_device)\n\n\nif __name__==\"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_causalLM_eval()\n        else:\n            test_mezo_sgd_causalLM_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_causalLM_eval()\n        else:\n            test_mezo2_sgd_causalLM_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/test_speed_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_names=(\"qwen3_0_6b\" \"qwen3_1_7b\" \"qwen3_4b\" \"qwen3_8b\" \"qwen3_14b\" \"qwen3_32b\")\ntask_ids=(\"causalLM\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_name in \"${model_names[@]}\"\ndo\n    for task_id in \"${task_ids[@]}\"\n    do\n        echo \"Testing model_name: $model_name, task_id: $task_id\"\n        \n        CMD1=\"python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo --max_steps 30\"\n        CMD2=\"python test/mezo_sgd/hf_qwen3/test_speed.py --model_name $model_name --zo_method zo2 --max_steps 30\"\n\n        OUT1=\"/tmp/output1_${model_name}_${task_id}.txt\"\n        OUT2=\"/tmp/output2_${model_name}_${task_id}.txt\"\n\n        $CMD1 2>&1 | tee $OUT1\n        $CMD2 2>&1 | tee $OUT2\n\n        echo \"Analyzing throughput...\"\n        \n        # Count the total number of lines and determine the number of iteration lines\n        total_lines1=$(wc -l < $OUT1)\n        total_lines2=$(wc -l < $OUT2)\n        iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)\n        iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n        # Calculate the starting line for the last 50% of iterations\n        start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))\n        start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n        # Calculate average tokens per second for the last 50% of the iterations\n        avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n        avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n        ratio=$(echo \"scale=2; $avg_tok_s2 / $avg_tok_s1 * 100\" | bc)\n\n        echo -e \"Model: $model_name, Task: $task_id\"\n        echo -e \"ZO average throughput (last 50% 
iterations): ${GREEN}$avg_tok_s1 tok/s${NC}\"\n        echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n        echo -e \"Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}\"\n\n        rm $OUT1 $OUT2\n    done\ndone"
  },
  {
    "path": "test/mezo_sgd/hf_qwen3/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport time\nimport argparse\nfrom tqdm import tqdm\nimport psutil\nimport os\nfrom transformers import Qwen3Config\nimport pynvml\n\ndef get_args():\n    args = argparse.ArgumentParser()\n    args.add_argument(\"--zo_method\", type=str, default=\"zo2\")\n    args.add_argument(\"--eval\", action=\"store_true\")\n    args.add_argument(\"--model_name\", type=str, default=\"qwen3_0_6b\")\n    args.add_argument(\"--model_dtype\", type=str, default=\"fp16\")\n    args.add_argument(\"--verbose\", action=\"store_true\")\n    args.add_argument(\"--max_steps\", type=int, default=3)\n    args.add_argument(\"--lr\", type=float, default=1e-3)\n    args.add_argument(\"--weight_decay\", type=float, default=1e-1)\n    args.add_argument(\"--zo_eps\", type=float, default=1e-3)\n    args.add_argument(\"--seed\", type=int, default=42)\n    args.add_argument(\"--batch_size\", type=int, default=1)\n    args.add_argument(\"--sequence_length\", type=int, default=2048)\n    args.add_argument(\"--overlap\", type=str, default=\"all\")\n    args.add_argument(\"--offloading_device\", type=str, default=\"cpu\")\n    args.add_argument(\"--working_device\", type=str, default=\"cuda:0\")\n    args = args.parse_args()\n    args.model_dtype = dtype_lookup[args.model_dtype]\n    return args\n\n\nclass Qwen3Configs:\n    qwen3_0_6b: Qwen3Config = Qwen3Config(num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=8, max_window_layers=28, hidden_size=1024, intermediate_size=3072, max_position_embeddings=40960, use_sliding_window=False)\n    qwen3_1_7b: Qwen3Config = Qwen3Config(num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=8, max_window_layers=28, hidden_size=2048, intermediate_size=6144, max_position_embeddings=40960, use_sliding_window=False)\n    qwen3_4b: Qwen3Config = Qwen3Config(num_hidden_layers=36, num_attention_heads=32, num_key_value_heads=8, 
max_window_layers=36, hidden_size=2560, intermediate_size=9728, max_position_embeddings=40960, use_sliding_window=False)\n    qwen3_8b: Qwen3Config = Qwen3Config(num_hidden_layers=36, num_attention_heads=32, num_key_value_heads=8, max_window_layers=36, hidden_size=4096, intermediate_size=12288, max_position_embeddings=40960, use_sliding_window=False)\n    qwen3_14b: Qwen3Config = Qwen3Config(num_hidden_layers=40, num_attention_heads=40, num_key_value_heads=8, max_window_layers=40, hidden_size=5120, intermediate_size=17408, max_position_embeddings=40960, use_sliding_window=False)\n    qwen3_32b: Qwen3Config = Qwen3Config(num_hidden_layers=64, num_attention_heads=64, num_key_value_heads=8, max_window_layers=64, hidden_size=5120, intermediate_size=25600, max_position_embeddings=40960, use_sliding_window=False)\n\n\ndtype_lookup = {\n    \"fp64\": torch.float64,\n    \"fp32\": torch.float32,\n    \"fp16\": torch.float16,\n    \"bf16\": torch.bfloat16\n}\n\n\ndef model_size(model: torch.nn.Module):\n    total_size = sum(p.numel() for p in model.parameters())\n    trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    return {\"total\": total_size, \"trainable\": trainable_size}\n\n\ndef prepare_data_for_causalLM(V, B, T, device='cuda'):\n    data_batch = torch.randint(0, V, (B, T)).to(device)\n    input_ids = data_batch\n    labels = data_batch\n    return input_ids, labels\n\n\n# GPU Memory Monitoring\npynvml.nvmlInit()\ndef check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):\n    # Check the peak memory usage\n    handle = pynvml.nvmlDeviceGetHandleByIndex(device)  # Adjust index if necessary\n    info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n    peak_memory = info.used / 1024**2\n    if use_tqdm:\n        tqdm.write(\"Peak GPU Memory after iteration {}: {:.2f} MB\".format(iter+1, peak_memory))\n    else:\n        print(f\"Peak GPU Memory after iteration {iter+1}: {peak_memory:.2f} MB\")\n\n# CPU Memory 
Monitoring\npeak_memory_cpu = 0\ndef check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):\n    global peak_memory_cpu\n    process = psutil.Process(os.getpid())\n    current_memory = process.memory_info().rss / (1024 ** 2)  # Convert to MB\n    if current_memory > peak_memory_cpu:\n        peak_memory_cpu = current_memory\n    if use_tqdm:\n        tqdm.write(f\"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB\")\n    else:\n        print(f\"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB\")\n\ndef reset_peak_cpu_memory_usage():\n    global peak_memory_cpu\n    peak_memory_cpu = 0\n    if torch.cuda.is_available():\n        torch.cuda.reset_peak_memory_stats()\n\n\ndef check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs):\n    t1 = time.time()\n    out = fn(*args, **kwargs)\n    t2 = time.time()\n    time_cost = t2-t1\n    throughtput = total_token_batch_size_per_iter / time_cost\n    if use_tqdm:\n        tqdm.write(\"Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s\".format(iter+1, time_cost*1e3, throughtput))\n    else:\n        print(\"Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s\".format(iter+1, time_cost*1e3, throughtput))\n"
  },
  {
    "path": "test/mezo_sgd/nanogpt/record_zo2_memory.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD2=\"python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30\"\n\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing Peak GPU and CPU Memory usage...\"\n    max_mem1=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n    max_mem2=$(grep 'Peak CPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n    if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n        echo \"Could not find memory usage data in the output.\"\n    else\n        echo -e \"Model: $model_name\"\n        echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n        echo -e \"ZO2 peak CPU memory: ${GREEN}$max_mem2 MB${NC}\"\n    fi\n\n    rm $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/record_zo2_speed.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD2=\"python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30\"\n\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing throughput...\"\n    \n    # Count the total number of lines and determine the number of iteration lines\n    total_lines2=$(wc -l < $OUT2)\n    iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n    # Calculate the starting line for the last 50% of iterations\n    start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n    # Calculate average tokens per second for the last 50% of the iterations\n    avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n    echo -e \"Model: $model_name\"\n    echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n    \n    rm $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_acc.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.nanogpt.mezo_sgd import get_nanogpt_mezo_sgd\nfrom zo2.model.nanogpt.model import GPTConfig, GPTConfigs\nfrom zo2.utils.utils import seed_everything\nfrom utils import model_size, prepare_data, get_args\n\ndef train_mezo_sgd(model, args, model_config, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(model_config.vocab_size, args.batch_size, model_config.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        loss = model(input_ids, pos, labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef train_mezo2_sgd(model, args, model_config, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(model_config.vocab_size, args.batch_size, model_config.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        loss = model(input_ids, pos, labels)\n        res = \"Iteration {}, loss: {}, projected grad: {}\"\n        tqdm.write(res.format(i, loss, model.opt.projected_grad))\n\ndef eval_mezo_sgd(model, args, model_config, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(model_config.vocab_size, args.batch_size, 
model_config.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        loss = model(input_ids, pos, labels)[-1]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef eval_mezo2_sgd(model, args, model_config, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(model_config.vocab_size, args.batch_size, model_config.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        loss = model(input_ids, pos, labels)[-1]\n        res = \"Iteration {}, loss: {}\"\n        tqdm.write(res.format(i, loss))\n\ndef test_mezo_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    train_mezo_sgd(model=model_mezo, \n               args=args, \n               model_config=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    train_mezo2_sgd(model=model, \n                          args=args, \n                          
model_config=cfg, \n                          device=args.working_device)\n\ndef test_mezo_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo_sgd(model=model_mezo, \n               args=args, \n               model_config=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo2_sgd(model=model, \n                          args=args, \n                          model_config=cfg, \n                          device=args.working_device)\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_eval()\n        else:\n            test_mezo_sgd_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_eval()\n        else:\n            test_mezo2_sgd_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_acc_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo --eval\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo2 --eval\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Comparing outputs...\"\n    echo -e \"Model: $model_id\"\n    paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n        split($2, loss1, \":\");\n        split($7, loss2, \":\");\n        diff_loss = loss1[2] - loss2[2];\n        if (loss1[2] == loss2[2])\n            printf \"Iteration %s: %s✓ loss match.%s\\n\", $2, green, nc;\n        else\n            printf \"Iteration %s: %s✗ Mismatch! ZO (loss): (%s), ZO2 (loss): (%s) \\tLoss diff: %.6f%s\\n\", $2, red, loss1[2], loss2[2], diff_loss, nc;\n    }'\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_acc_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_acc.py --model_id $model_id --zo_method zo2\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Comparing outputs...\"\n    echo -e \"Model: $model_name\"\n    paste <(grep 'Iteration' $OUT1) <(grep 'Iteration' $OUT2) | awk -v green=\"$GREEN\" -v red=\"$RED\" -v nc=\"$NC\" '{\n        split($4, loss1, \",\");\n        split($7, proj1, \",\");\n        split($11, loss2, \",\");\n        split($14, proj2, \",\");\n        diff_loss = loss1[1] - loss2[1];\n        diff_proj = proj1[1] - proj2[1];\n        if (loss1[1] == loss2[1] && proj1[1] == proj2[1])\n            printf \"Iteration %d: %s✓ loss and projected grad match.%s\\n\", $2, green, nc;\n        else\n            printf \"Iteration %d: %s✗ Mismatch! ZO (loss, grad): (%s, %s), ZO2 (loss, grad): (%s, %s)\\n \\tLoss diff: %.6f, Proj grad diff: %.6f%s\\n\", $2, red, loss1[1], proj1[1], loss2[1], proj2[1], diff_loss, diff_proj, nc;\n    }'\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_memory.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.nanogpt.mezo_sgd import get_nanogpt_mezo_sgd\nfrom zo2.model.nanogpt.model import GPTConfig, GPTConfigs\nfrom zo2.utils.utils import seed_everything\nfrom utils import model_size, prepare_data, get_args, check_peak_gpu_memory_usage, reset_peak_cpu_memory_usage, check_and_update_peak_cpu_memory_usage\n\ndef train_mezo_sgd(model, args, modelConfig, device='cuda:0'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids, pos, labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef train_mezo2_sgd(model, args, modelConfig, device='cuda:0'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        model(input_ids, pos, labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo_sgd(model, args, modelConfig, device='cuda:0'):\n    seed_everything(args.seed)\n    
total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids, pos, labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef eval_mezo2_sgd(model, args, modelConfig, device='cuda:0'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    torch.cuda.reset_peak_memory_stats()\n    reset_peak_cpu_memory_usage()\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        model(input_ids, pos, labels)\n        check_peak_gpu_memory_usage(i, int(device[-1]), True)\n        check_and_update_peak_cpu_memory_usage(i, True)\n\ndef test_mezo_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    train_mezo_sgd(model=model_mezo, \n               args=args, \n               modelConfig=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = 
MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    train_mezo2_sgd(model=model, \n                          args=args, \n                          modelConfig=cfg, \n                          device=args.working_device)\n\ndef test_mezo_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo_sgd(model=model_mezo, \n               args=args, \n               modelConfig=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo2_sgd(model=model, \n                          args=args, \n                          modelConfig=cfg, \n                          device=args.working_device)\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_eval()\n        else:\n            
test_mezo_sgd_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_eval()\n        else:\n            test_mezo2_sgd_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_memory_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo --max_steps 30 --eval\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30 --eval\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing Peak GPU Memory usage...\"\n    max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n    max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n    if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n        echo \"Could not find memory usage data in the output.\"\n    else\n        ratio=$(echo \"scale=2; $max_mem2 / $max_mem1 * 100\" | bc)\n        echo -e \"Model: $model_name\"\n        echo -e \"ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n        echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}\"\n        echo -e \"Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}\"\n    fi\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_memory_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo --max_steps 30\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_memory.py --model_id $model_id --zo_method zo2 --max_steps 30\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing Peak GPU Memory usage...\"\n    max_mem1=$(grep 'Peak GPU Memory' $OUT1 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n    max_mem2=$(grep 'Peak GPU Memory' $OUT2 | awk '{print $7}' | sed 's/ MB//' | sort -nr | head -1)\n\n    if [ -z \"$max_mem1\" ] || [ -z \"$max_mem2\" ]; then\n        echo \"Could not find memory usage data in the output.\"\n    else\n        ratio=$(echo \"scale=2; $max_mem2 / $max_mem1 * 100\" | bc)\n        echo -e \"Model: $model_name\"\n        echo -e \"ZO peak GPU memory: ${GREEN}$max_mem1 MB${NC}\"\n        echo -e \"ZO2 peak GPU memory: ${GREEN}$max_mem2 MB${NC}\"\n        echo -e \"Memory usage ratio of ZO2 to ZO: ${GREEN}$ratio%${NC}\"\n    fi\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_speed.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append(\"../zo2\")\n\nimport torch\nfrom tqdm import tqdm\n\nfrom zo2.config.mezo_sgd import MeZOSGDConfig\nfrom zo2.model.nanogpt.mezo_sgd import get_nanogpt_mezo_sgd\nfrom zo2.model.nanogpt.model import GPTConfig, GPTConfigs\nfrom zo2.utils.utils import seed_everything\nfrom utils import model_size, prepare_data, get_args, check_throughput\n\ndef train_mezo_sgd(model, args, modelConfig, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        check_throughput(i, args.batch_size*modelConfig.block_size, model, input_ids, pos, labels, use_tqdm=True)\n\ndef train_mezo2_sgd(model, args, modelConfig, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_train()\n        check_throughput(i, args.batch_size*modelConfig.block_size, model, input_ids, pos, labels, use_tqdm=True)\n\ndef eval_mezo_sgd(model, args, modelConfig, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n       
 check_throughput(i, args.batch_size*modelConfig.block_size, model, input_ids, pos, labels, use_tqdm=True)\n\ndef eval_mezo2_sgd(model, args, modelConfig, device='cuda'):\n    seed_everything(args.seed)\n    total_parameters = model_size(model)[\"total\"]\n    print(f\"model size: {total_parameters/1024**3:.2f} B\")\n    print(\"Init dataset\")\n    input_ids, pos, labels = prepare_data(modelConfig.vocab_size, args.batch_size, modelConfig.block_size, device=device)\n    for i in tqdm(range(args.max_steps)):\n        model.zo_eval()\n        check_throughput(i, args.batch_size*modelConfig.block_size, model, input_ids, pos, labels, use_tqdm=True)\n\ndef test_mezo_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    train_mezo_sgd(model=model_mezo, \n               args=args, \n               modelConfig=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_training():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    train_mezo2_sgd(model=model, \n                          args=args, \n                          modelConfig=cfg, \n                          device=args.working_device)\n\ndef test_mezo_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = 
GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        working_device=args.working_device)\n    zo_cfg.zo2 = False\n    torch.set_default_dtype(args.model_dtype)\n    model_mezo = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg).to(args.working_device)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo_sgd(model=model_mezo, \n               args=args, \n               modelConfig=cfg, \n               device=args.working_device)\n\ndef test_mezo2_sgd_eval():\n    seed_everything(args.seed)\n    cfgs = GPTConfigs()\n    cfg = getattr(cfgs, args.model_id)\n    zo_cfg = MeZOSGDConfig(lr=args.lr, weight_decay=args.weight_decay, eps=args.zo_eps,\n        offloading_device=args.offloading_device, working_device=args.working_device)\n    zo_cfg.zo2 = True\n    torch.set_default_dtype(args.model_dtype)\n    model = get_nanogpt_mezo_sgd(zo_cfg)(cfg, zo_cfg)\n    torch.set_default_dtype(original_dtype)\n    eval_mezo2_sgd(model=model, \n                          args=args, \n                          modelConfig=cfg, \n                          device=args.working_device)\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    original_dtype = torch.get_default_dtype()\n    if args.zo_method == \"zo\":\n        if args.eval:\n            test_mezo_sgd_eval()\n        else:\n            test_mezo_sgd_training()\n    elif args.zo_method == \"zo2\":\n        if args.eval:\n            test_mezo2_sgd_eval()\n        else:\n            test_mezo2_sgd_training()\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_speed_eval.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo --max_steps 30 --eval\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30 --eval\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing throughput...\"\n    \n    # Count the total number of lines and determine the number of iteration lines\n    total_lines1=$(wc -l < $OUT1)\n    total_lines2=$(wc -l < $OUT2)\n    iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)\n    iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n    # Calculate the starting line for the last 50% of iterations\n    start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))\n    start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n    # Calculate average tokens per second for the last 50% of the iterations\n    avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n    avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n    ratio=$(echo \"scale=2; $avg_tok_s2 / $avg_tok_s1 * 100\" | bc)\n\n    echo -e \"Model: $model_name\"\n    echo -e \"ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}\"\n    echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n    echo -e \"Throughput 
ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}\"\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/test_speed_train.sh",
    "content": "#!/bin/bash\n\nset -e\nset -o pipefail\n\nmodel_ids=(\"gpt2\" \"gpt2_medium\" \"gpt2_large\" \"gpt2_xl\" \"opt_125m\" \"opt_350m\" \"opt_1_3b\" \"opt_2_7b\" \"opt_6_7b\" \"opt_13b\" \"opt_30b\" \"opt_66b\" \"opt_175b\")\n\n# ANSI color codes\nGREEN='\\033[0;32m'\nRED='\\033[0;31m'\nNC='\\033[0m'\n\nfor model_id in \"${model_ids[@]}\"\ndo\n    echo \"Testing model_id: $model_id\"\n    \n    CMD1=\"python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo --max_steps 30\"\n    CMD2=\"python test/mezo_sgd/nanogpt/test_speed.py --model_id $model_id --zo_method zo2 --max_steps 30\"\n\n    OUT1=\"/tmp/output1_$model_id.txt\"\n    OUT2=\"/tmp/output2_$model_id.txt\"\n\n    $CMD1 2>&1 | tee $OUT1\n    $CMD2 2>&1 | tee $OUT2\n\n    echo \"Analyzing throughput...\"\n    \n    # Count the total number of lines and determine the number of iteration lines\n    total_lines1=$(wc -l < $OUT1)\n    total_lines2=$(wc -l < $OUT2)\n    iter_lines1=$(grep -c 'Time cost after iteration' $OUT1)\n    iter_lines2=$(grep -c 'Time cost after iteration' $OUT2)\n\n    # Calculate the starting line for the last 50% of iterations\n    start_line1=$(($total_lines1 - $iter_lines1 + $(($iter_lines1 / 2 + 1))))\n    start_line2=$(($total_lines2 - $iter_lines2 + $(($iter_lines2 / 2 + 1))))\n\n    # Calculate average tokens per second for the last 50% of the iterations\n    avg_tok_s1=$(tail -n +$start_line1 $OUT1 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n    avg_tok_s2=$(tail -n +$start_line2 $OUT2 | grep 'tok/s' | awk '{print $8}' | awk '{total += $1; count++} END {print total/count}')\n\n    ratio=$(echo \"scale=2; $avg_tok_s2 / $avg_tok_s1 * 100\" | bc)\n\n    echo -e \"Model: $model_id\"\n    echo -e \"ZO average throughput (last 50% iterations): ${GREEN}$avg_tok_s1 tok/s${NC}\"\n    echo -e \"ZO2 average throughput (last 50% iterations): ${GREEN}$avg_tok_s2 tok/s${NC}\"\n    echo -e \"Throughput ratio of ZO2 to ZO (last 50% iterations): ${GREEN}$ratio%${NC}\"\n\n    rm $OUT1 $OUT2\ndone"
  },
  {
    "path": "test/mezo_sgd/nanogpt/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport time\nimport argparse\nfrom tqdm import tqdm\nimport psutil\nimport os\nimport pynvml\n\n\ndef get_args():\n    args = argparse.ArgumentParser()\n    args.add_argument(\"--zo_method\", type=str, default=\"zo2\")\n    args.add_argument(\"--eval\", action=\"store_true\")\n    args.add_argument(\"--model_id\", type=str, default=\"gpt2\")\n    args.add_argument(\"--model_dtype\", type=str, default=\"fp32\")\n    args.add_argument(\"--verbose\", action=\"store_true\")\n    args.add_argument(\"--max_steps\", type=int, default=3)\n    args.add_argument(\"--lr\", type=float, default=1e-4)\n    args.add_argument(\"--weight_decay\", type=float, default=1e-1)\n    args.add_argument(\"--zo_eps\", type=float, default=1e-3)\n    args.add_argument(\"--seed\", type=int, default=42)\n    args.add_argument(\"--batch_size\", type=int, default=1)\n    args.add_argument(\"--offloading_device\", type=str, default=\"cpu\")\n    args.add_argument(\"--working_device\", type=str, default=\"cuda:0\")\n    args = args.parse_args()\n    args.model_dtype = dtype_lookup[args.model_dtype]\n    return args\n\n\ndtype_lookup = {\n    \"fp64\": torch.float64,\n    \"fp32\": torch.float32,\n    \"fp16\": torch.float16,\n    \"bf16\": torch.bfloat16\n}\n\n\ndef model_size(model: torch.nn.Module):\n    total_size = sum(p.numel() for p in model.parameters())\n    trainable_size = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    return {\"total\": total_size, \"trainable\": trainable_size}\n\n\ndef prepare_data(V, B, T, device='cuda'):\n    data_batch = torch.randint(0, V, (B, T+1)).to(device)\n    input_ids = data_batch[:, :T]\n    labels = data_batch[:, 1:T+1]\n    pos = torch.arange(input_ids.shape[1], dtype=torch.long, device=device).unsqueeze(0)\n    return input_ids, pos, labels\n\n\n# GPU Memory Monitoring\npynvml.nvmlInit()\ndef 
check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False):\n    # Check the peak memory usage\n    handle = pynvml.nvmlDeviceGetHandleByIndex(device)  # Adjust index if necessary\n    info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n    peak_memory = info.used / 1024**2\n    if use_tqdm:\n        tqdm.write(\"Peak GPU Memory after iteration {}: {:.2f} MB\".format(iter+1, peak_memory))\n    else:\n        print(f\"Peak GPU Memory after iteration {iter+1}: {peak_memory:.2f} MB\")\n\n# CPU Memory Monitoring\npeak_memory_cpu = 0\ndef check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False):\n    global peak_memory_cpu\n    process = psutil.Process(os.getpid())\n    current_memory = process.memory_info().rss / (1024 ** 2)  # Convert to MB\n    if current_memory > peak_memory_cpu:\n        peak_memory_cpu = current_memory\n    if use_tqdm:\n        tqdm.write(f\"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB\")\n    else:\n        print(f\"Peak CPU Memory after iteration {iter+1}: {peak_memory_cpu:.2f} MB\")\n\ndef reset_peak_cpu_memory_usage():\n    global peak_memory_cpu\n    peak_memory_cpu = 0\n    if torch.cuda.is_available():\n        torch.cuda.reset_peak_memory_stats()\n\ndef check_throughput(iter, total_token_batch_size_per_iter, fn, *args, use_tqdm=False, **kwargs):\n    t1 = time.time()\n    out = fn(*args, **kwargs)\n    t2 = time.time()\n    time_cost = t2-t1\n    throughtput = total_token_batch_size_per_iter / time_cost\n    if use_tqdm:\n        tqdm.write(\"Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s\".format(iter+1, time_cost*1e3, throughtput))\n    else:\n        print(\"Time cost after iteration {}: {:.2f} ms, {:.2f} tok/s\".format(iter+1, time_cost*1e3, throughtput))\n"
  },
  {
    "path": "tutorial/README.md",
    "content": "# API of ZO2\n\nWelcome to the ZO2 API documentation!\n\n## Standard Usage\n\n### 1. Quick Start\n\nFor a straightforward introduction to using ZO2, refer to the Jupyter notebook: [demo.ipynb](./demo.ipynb)\n\n### 2. Huggingface Trainer\n\nTo see how ZO2 can be integrated with the Huggingface Trainer for efficient model training, check out: [huggingface.ipynb](./huggingface.ipynb)\n\n### 3. Extend ZO2 to Your Own PyTorch Models\n\nLearn how to apply ZO2 to your own PyTorch models by following the example of building a nanogpt model: [nanogpt.ipynb](./nanogpt.ipynb).\n"
  },
  {
    "path": "tutorial/colab.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Environment Setting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"# Set CUDA_VISIBLE_DEVICES to 0 to make only the first GPU visible\\n\",\n    \"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install -q condacolab\\n\",\n    \"import condacolab\\n\",\n    \"condacolab.install()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import condacolab\\n\",\n    \"condacolab.check()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"!rm -rf zo2/\\n\",\n    \"!git clone https://github.com/liangyuwang/zo2.git\\n\",\n    \"print(\\\"Current working directory:\\\", os.getcwd())\\n\",\n    \"os.chdir('zo2/')\\n\",\n    \"print(\\\"New working directory:\\\", os.getcwd())\\n\",\n    \"\\n\",\n    \"!conda env update -n base -f env.yml\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using [MeZO Runner](../example/mezo_runner/) on Supported Tasks\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"print(\\\"Current working directory:\\\", os.getcwd())\\n\",\n    \"os.chdir('./example/mezo_runner/')\\n\",\n    \"print(\\\"New working directory:\\\", os.getcwd())\\n\",\n    \"\\n\",\n    \"!MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash 
mezo.sh\\n\",\n    \"\\n\",\n    \"os.chdir('../../tutorial/')\\n\",\n    \"print(\\\"New working directory:\\\", os.getcwd())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Huggingface Trainer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.append(\\\"../\\\")\\n\",\n    \"\\n\",\n    \"from tqdm.auto import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoTokenizer, \\n\",\n    \"    TrainingArguments,\\n\",\n    \"    DataCollatorForLanguageModeling\\n\",\n    \")\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"from zo2 import (\\n\",\n    \"    ZOConfig,\\n\",\n    \"    zo_hf_init,\\n\",\n    \")\\n\",\n    \"from zo2.trainer.hf_transformers.trainer import ZOTrainer\\n\",\n    \"from zo2.trainer.hf_trl.sft_trainer import ZOSFTTrainer\\n\",\n    \"from zo2.utils import seed_everything\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Hyperparameter\\n\",\n    \"zo_method = \\\"zo2\\\"\\n\",\n    \"eval_mode = False\\n\",\n    \"model_name = \\\"facebook/opt-2.7b\\\"\\n\",\n    \"verbose = True\\n\",\n    \"max_steps = 300\\n\",\n    \"learning_rate = 1e-7\\n\",\n    \"weight_decay = 1e-1\\n\",\n    \"zo_eps = 1e-3\\n\",\n    \"seed = 42\\n\",\n    \"offloading_device = \\\"cpu\\\"\\n\",\n    \"working_device = \\\"cuda:0\\\"\\n\",\n    \"max_train_data = None\\n\",\n    \"max_eval_data = None\\n\",\n    \"use_cache = True\\n\",\n    \"max_new_tokens = 50\\n\",\n    \"temperature = 1.0\\n\",\n    \"seed_everything(seed)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# ZO steps\\n\",\n    \"zo_config = ZOConfig(\\n\",\n    
\"    method=\\\"mezo-sgd\\\", \\n\",\n    \"    zo2=zo_method==\\\"zo2\\\", \\n\",\n    \"    lr=learning_rate,\\n\",\n    \"    weight_decay=weight_decay,\\n\",\n    \"    eps=zo_eps,\\n\",\n    \"    offloading_device=offloading_device,\\n\",\n    \"    working_device=working_device,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Load ZO model\\n\",\n    \"with zo_hf_init(zo_config):\\n\",\n    \"    from transformers import OPTForCausalLM\\n\",\n    \"    model = OPTForCausalLM.from_pretrained(model_name)\\n\",\n    \"    model.zo_init(zo_config)\\n\",\n    \"if zo_method != \\\"zo2\\\": \\n\",\n    \"    model = model.to(working_device)\\n\",\n    \"print(f\\\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Prepare dataset\\n\",\n    \"dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')\\n\",\n    \"\\n\",\n    \"# tokenizing dataset\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\\n\",\n    \"block_size = tokenizer.model_max_length\\n\",\n    \"def tokenize_function(examples):\\n\",\n    \"    return tokenizer(examples[\\\"text\\\"])\\n\",\n    \"tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\\\"text\\\"])\\n\",\n    \"def group_texts(examples):\\n\",\n    \"    # Concatenate all texts.\\n\",\n    \"    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\\n\",\n    \"    total_length = len(concatenated_examples[list(examples.keys())[0]])\\n\",\n    \"    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\\n\",\n    \"        # customize this part to your needs.\\n\",\n    \"    total_length = (total_length // block_size) * block_size\\n\",\n    \"    # Split by chunks of max_len.\\n\",\n    \"    result = {\\n\",\n    \"        k: [t[i : i + 
block_size] for i in range(0, total_length, block_size)]\\n\",\n    \"        for k, t in concatenated_examples.items()\\n\",\n    \"    }\\n\",\n    \"    result[\\\"labels\\\"] = result[\\\"input_ids\\\"].copy()\\n\",\n    \"    return result\\n\",\n    \"lm_datasets = tokenized_datasets.map(\\n\",\n    \"    group_texts,\\n\",\n    \"    batched=True,\\n\",\n    \"    batch_size=1000,\\n\",\n    \"    num_proc=4,\\n\",\n    \")\\n\",\n    \"data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# trainer init\\n\",\n    \"training_args = TrainingArguments(\\n\",\n    \"    \\\"test-trainer\\\", \\n\",\n    \"    max_steps=max_steps,\\n\",\n    \"    save_strategy=\\\"no\\\", \\n\",\n    \"    logging_steps=10,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"trainer = ZOTrainer(\\n\",\n    \"    model,\\n\",\n    \"    training_args,\\n\",\n    \"    train_dataset=tokenized_datasets[\\\"train\\\"],\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    processing_class=tokenizer,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# 'ZOTrainer' provides the capability to register pre-hooks and post-hooks during zo_step\\n\",\n    \"def drop_invalid_data(model, inputs, loss):\\n\",\n    \"    # Extract projected_grad, handle both tensor and scalar cases\\n\",\n    \"    projected_grad = model.opt.projected_grad\\n\",\n    \"    if isinstance(projected_grad, torch.Tensor):\\n\",\n    \"        projected_grad_is_nan = torch.isnan(projected_grad).any()\\n\",\n    \"    else:\\n\",\n    \"        projected_grad_is_nan = projected_grad != projected_grad  # Check for NaN in scalars\\n\",\n    \"    if torch.isnan(loss) or projected_grad_is_nan:\\n\",\n    \"        tqdm.write(\\\"'loss': {} or 'projected_grad': {} is nan. 
Drop this step.\\\".format(\\n\",\n    \"            loss, model.opt.projected_grad\\n\",\n    \"        ))\\n\",\n    \"        model.opt.projected_grad = 0  # Reset projected_grad to prevent parameter updates\\n\",\n    \"    return model, inputs, loss\\n\",\n    \"trainer.register_zo2_training_step_post_hook(drop_invalid_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# trainer step\\n\",\n    \"trainer.train()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"language_info\": {\n   \"name\": \"python\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "tutorial/demo.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Fine-tune HF Model with Your Custom Training Loop\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.append(\\\"../\\\")\\n\",\n    \"\\n\",\n    \"from tqdm.auto import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"from zo2 import (\\n\",\n    \"    ZOConfig,\\n\",\n    \"    zo_hf_init,\\n\",\n    \")\\n\",\n    \"from zo2.utils import seed_everything\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Hyperparameter\\n\",\n    \"zo_method = \\\"zo2\\\"\\n\",\n    \"eval_mode = False\\n\",\n    \"model_name = \\\"facebook/opt-2.7b\\\"\\n\",\n    \"verbose = True\\n\",\n    \"max_steps = 100\\n\",\n    \"learning_rate = 1e-5\\n\",\n    \"weight_decay = 1e-1\\n\",\n    \"zo_eps = 1e-3\\n\",\n    \"seed = 42\\n\",\n    \"offloading_device = \\\"cpu\\\"\\n\",\n    \"working_device = \\\"cuda:0\\\"\\n\",\n    \"use_cache = True\\n\",\n    \"max_new_tokens = 50\\n\",\n    \"temperature = 1.0\\n\",\n    \"seed_everything(seed)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# ZO steps\\n\",\n    \"zo_config = ZOConfig(\\n\",\n    \"    method=\\\"mezo-sgd\\\", \\n\",\n    \"    zo2=zo_method==\\\"zo2\\\", \\n\",\n    \"    lr=learning_rate,\\n\",\n    \"    weight_decay=weight_decay,\\n\",\n    \"    eps=zo_eps,\\n\",\n    \"    offloading_device=offloading_device,\\n\",\n    \"    working_device=working_device,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Load ZO model\\n\",\n    \"with zo_hf_init(zo_config):\\n\",\n    \"    from transformers import OPTForCausalLM\\n\",\n    \"    
model = OPTForCausalLM.from_pretrained(model_name)\\n\",\n    \"    model.zo_init(zo_config)\\n\",\n    \"if zo_method != \\\"zo2\\\": \\n\",\n    \"    model = model.to(working_device)\\n\",\n    \"print(f\\\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Prepare some data\\n\",\n    \"dataset = \\\"\\\"\\\"\\n\",\n    \"    What is ZO2? \\n\",\n    \"    ZO2 is an innovative framework specifically designed to enhance the fine-tuning of large language models (LLMs) using zeroth-order (ZO) optimization techniques and advanced offloading technologies. \\n\",\n    \"    This framework is particularly tailored for setups with limited GPU memory, enabling the fine-tuning of models that were previously unmanageable due to hardware constraints. \\n\",\n    \"    As the scale of Large Language Models (LLMs) continues to grow, reaching parameter counts in the hundreds of billions, managing GPU memory resources effectively becomes crucial. \\n\",\n    \"    Efficient GPU memory management is crucial not only because it directly influences model performance and training speed, but also because GPU memory is both expensive and limited in quantity. \\n\",\n    \"    However, this creates a significant challenge in handling ever-larger models within the physical constraints of current hardware technologies. \\n\",\n    \"    CPU offloading has become a crucial technique for overcoming this challenge. \\n\",\n    \"    It involves transferring computations and data from the GPU to the CPU, specifically targeting data or parameters that are less frequently accessed. \\n\",\n    \"    By offloading these inactive tensors of the neural network, CPU offloading effectively alleviates the memory and computational pressures on GPUs. 
\\n\",\n    \"    While CPU offloading has been commonly applied in inference to manage memory-intensive tasks, its application in training, especially fine-tuning, remains less explored. \\n\",\n    \"    Recently, some works have tried to introduce CPU offloading into LLM training. \\n\",\n    \"    However, they are typically constrained by the capabilities of first-order optimizers such as SGD and Adaptive Moment Estimation (AdamW), and limited GPU memory, restricting large-scale model scalability on single GPU systems. \\n\",\n    \"    Using first-order optimizers introduces inefficiencies in CPU offloading: Multiple communication operations during the training of LLMs necessitate offloading the same data twice—once for each pass. \\n\",\n    \"    This redundancy not only doubles the communication volume between the CPU and GPU but also introduces significant latency due to repetitive data transfers. \\n\",\n    \"    Furthermore, both parameters and activations are required in the backward pass to complete gradient computations. \\n\",\n    \"    This means that parameters and activation values must be offloaded during each forward pass and re-uploaded to the GPU for the backward pass, increasing the volume of data transferred, which severely impacts training throughput. \\n\",\n    \"    On the other hand, zeroth-order (ZO) methods offer a novel approach to fine-tuning LLMs. \\n\",\n    \"    These methods utilize dual forward passes to estimate parameter gradients and subsequently update parameters. \\n\",\n    \"    This approach eliminates the traditional reliance on backward passes, thereby streamlining the training process by significantly reducing the number of computational steps required. \\n\",\n    \"    Based on these observations, we conjecture that ZO's architecture is particularly well-suited for CPU offloading strategies. 
\\n\",\n    \"    By eliminating backward passes and the need to store activation values, it can significantly reduce GPU memory demands through efficient parameter offloading. \\n\",\n    \"    However, despite these advantages, ZO training via CPU offloading introduces new challenges, particularly in the realm of CPU-to-GPU communication. \\n\",\n    \"    Transferring parameters between the CPU and GPU, which is crucial for maintaining gradient computation and model updates, becomes a critical bottleneck. \\n\",\n    \"    Although ZO methods inherently extend computation times because of the dual forward passes, potentially allowing for better overlap between computation and communication, there remain significant inefficiencies. \\n\",\n    \"    The necessity to upload parameters to the GPU for upcoming computations introduces a large volume of communications. To tackle the inefficiencies highlighted, we introduce ZO2, a novel framework specifically designed for ZO fine-tuning in LLMs with CPU offloading. \\n\",\n    \"    This framework utilizes the unique dual forward pass architecture of ZO methods to optimize interactions between CPU and GPU, significantly enhancing both computational and communication efficiency. \\n\",\n    \"    By building a high-performance dynamic scheduler, ZO2 achieves substantial overlaps in communication and computation. \\n\",\n    \"    These innovations make it feasible to fine-tune extremely large models, such as the OPT-175B, with over 175 billion parameters, on a single GPU equipped with just 18GB of memory usage—a capability previously unattainable with conventional methods. 
\\n\",\n    \"    Additionally, our efficient framework operates without any extra time cost and decreases in accuracy compared to standard ZO methodologies.\\\"\\\"\\\"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\\n\",\n    \"data_batch = tokenizer(dataset, add_special_tokens=True, return_tensors='pt').input_ids.to(working_device)\\n\",\n    \"T = min(data_batch.shape[1] - 1, model.config.max_position_embeddings)\\n\",\n    \"print(f\\\"Fine-tuning model {model_name} with {T} tokens dataset: \\\\n{dataset}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Training loop\\n\",\n    \"for i in tqdm(range(max_steps)):\\n\",\n    \"    model.zo_train()\\n\",\n    \"    loss = model(input_ids=data_batch, labels=data_batch)\\n\",\n    \"\\n\",\n    \"    # eval\\n\",\n    \"    if eval_mode:\\n\",\n    \"        if i==0:\\n\",\n    \"            tqdm.write(\\\"Warning: please notice that ZO2 does not optimize the evaluation, so it may be very slow.\\\")\\n\",\n    \"        model.zo_eval()\\n\",\n    \"        output = model(input_ids=data_batch, labels=data_batch)\\n\",\n    \"        res = \\\"Iteration {}, train loss: {}, projected grad: {}, eval loss: {}\\\"\\n\",\n    \"        tqdm.write(res.format(i, loss, model.opt.projected_grad, output[\\\"loss\\\"]))\\n\",\n    \"    else:\\n\",\n    \"        res = \\\"Iteration {}, train loss: {}, projected grad: {}\\\"\\n\",\n    \"        tqdm.write(res.format(i, loss, model.opt.projected_grad))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# inference\\n\",\n    \"print(\\\"Doing inference...\\\")\\n\",\n    \"print(\\\"Warning: please notice that ZO2 does not optimize the inference, so it may be very slow.\\\")\\n\",\n    \"model.zo_eval()\\n\",\n    \"prompt = \\\"What is ZO2 and how ZO2 enhance 
the fine-tuning of large language models?\\\"\\n\",\n    \"inputs = tokenizer(prompt, return_tensors='pt').to(working_device)\\n\",\n    \"inputs = {\\\"input_ids\\\": inputs.input_ids}\\n\",\n    \"for _ in tqdm(range(max_new_tokens)):\\n\",\n    \"    outputs = model(**inputs, return_dict=True)\\n\",\n    \"    next_token_logits = outputs.logits[:, -1, :]\\n\",\n    \"    if temperature == 1.0:\\n\",\n    \"        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)\\n\",\n    \"    else:\\n\",\n    \"        scaled_logits = next_token_logits / temperature\\n\",\n    \"        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)\\n\",\n    \"        next_token = torch.multinomial(probs, num_samples=1)\\n\",\n    \"    inputs = torch.cat([inputs[\\\"input_ids\\\"], next_token], dim=-1)\\n\",\n    \"    generated_text = tokenizer.decode(inputs[0])\\n\",\n    \"    inputs = {\\\"input_ids\\\": inputs}\\n\",\n    \"print(f\\\"Question: {prompt}\\\")\\n\",\n    \"print(f\\\"Response: {generated_text[len(prompt)+4:]}...\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"mezo\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "tutorial/huggingface.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Environment Setting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"# Set CUDA_VISIBLE_DEVICES to 0 to make only the first GPU visible\\n\",\n    \"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using [MeZO Runner](../example/mezo_runner/) on Supported Tasks\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"print(\\\"Current working directory:\\\", os.getcwd())\\n\",\n    \"os.chdir('../example/mezo_runner/')\\n\",\n    \"print(\\\"New working directory:\\\", os.getcwd())\\n\",\n    \"\\n\",\n    \"!MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\\n\",\n    \"\\n\",\n    \"os.chdir('../../tutorial/')\\n\",\n    \"print(\\\"New working directory:\\\", os.getcwd())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Using Huggingface Trainer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.append(\\\"../\\\")\\n\",\n    \"\\n\",\n    \"from tqdm.auto import tqdm\\n\",\n    \"import torch\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoTokenizer, \\n\",\n    \"    TrainingArguments,\\n\",\n    \"    DataCollatorForLanguageModeling\\n\",\n    \")\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"from zo2 import (\\n\",\n    \"    ZOConfig,\\n\",\n    \"    zo_hf_init,\\n\",\n    \")\\n\",\n    \"from zo2.trainer.hf_transformers.trainer import 
ZOTrainer\\n\",\n    \"from zo2.trainer.hf_trl.sft_trainer import ZOSFTTrainer\\n\",\n    \"from zo2.utils import seed_everything\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Hyperparameter\\n\",\n    \"zo_method = \\\"zo2\\\"\\n\",\n    \"eval_mode = False\\n\",\n    \"model_name = \\\"facebook/opt-2.7b\\\"\\n\",\n    \"verbose = True\\n\",\n    \"max_steps = 300\\n\",\n    \"learning_rate = 1e-7\\n\",\n    \"weight_decay = 1e-1\\n\",\n    \"zo_eps = 1e-3\\n\",\n    \"seed = 42\\n\",\n    \"offloading_device = \\\"cpu\\\"\\n\",\n    \"working_device = \\\"cuda:0\\\"\\n\",\n    \"max_train_data = None\\n\",\n    \"max_eval_data = None\\n\",\n    \"use_cache = True\\n\",\n    \"max_new_tokens = 50\\n\",\n    \"temperature = 1.0\\n\",\n    \"seed_everything(seed)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# ZO steps\\n\",\n    \"zo_config = ZOConfig(\\n\",\n    \"    method=\\\"mezo-sgd\\\", \\n\",\n    \"    zo2=zo_method==\\\"zo2\\\", \\n\",\n    \"    lr=learning_rate,\\n\",\n    \"    weight_decay=weight_decay,\\n\",\n    \"    eps=zo_eps,\\n\",\n    \"    offloading_device=offloading_device,\\n\",\n    \"    working_device=working_device,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Load ZO model\\n\",\n    \"with zo_hf_init(zo_config):\\n\",\n    \"    from transformers import OPTForCausalLM\\n\",\n    \"    model = OPTForCausalLM.from_pretrained(model_name)\\n\",\n    \"    model.zo_init(zo_config)\\n\",\n    \"if zo_method != \\\"zo2\\\": \\n\",\n    \"    model = model.to(working_device)\\n\",\n    \"print(f\\\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Prepare dataset\\n\",\n    
\"dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')\\n\",\n    \"\\n\",\n    \"# tokenizing dataset\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\\n\",\n    \"block_size = tokenizer.model_max_length\\n\",\n    \"def tokenize_function(examples):\\n\",\n    \"    return tokenizer(examples[\\\"text\\\"])\\n\",\n    \"tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\\\"text\\\"])\\n\",\n    \"def group_texts(examples):\\n\",\n    \"    # Concatenate all texts.\\n\",\n    \"    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\\n\",\n    \"    total_length = len(concatenated_examples[list(examples.keys())[0]])\\n\",\n    \"    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\\n\",\n    \"        # customize this part to your needs.\\n\",\n    \"    total_length = (total_length // block_size) * block_size\\n\",\n    \"    # Split by chunks of max_len.\\n\",\n    \"    result = {\\n\",\n    \"        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\\n\",\n    \"        for k, t in concatenated_examples.items()\\n\",\n    \"    }\\n\",\n    \"    result[\\\"labels\\\"] = result[\\\"input_ids\\\"].copy()\\n\",\n    \"    return result\\n\",\n    \"lm_datasets = tokenized_datasets.map(\\n\",\n    \"    group_texts,\\n\",\n    \"    batched=True,\\n\",\n    \"    batch_size=1000,\\n\",\n    \"    num_proc=4,\\n\",\n    \")\\n\",\n    \"data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# trainer init\\n\",\n    \"training_args = TrainingArguments(\\n\",\n    \"    \\\"test-trainer\\\", \\n\",\n    \"    max_steps=max_steps,\\n\",\n    \"    save_strategy=\\\"no\\\", \\n\",\n    \"    logging_steps=10,\\n\",\n    \")\\n\",\n    \"\\n\",\n 
   \"trainer = ZOTrainer(\\n\",\n    \"    model,\\n\",\n    \"    training_args,\\n\",\n    \"    train_dataset=tokenized_datasets[\\\"train\\\"],\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    processing_class=tokenizer,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# 'ZOTrainer' provides the capability to register pre-hooks and post-hooks during zo_step\\n\",\n    \"def drop_invalid_data(model, inputs, loss):\\n\",\n    \"    # Extract projected_grad, handle both tensor and scalar cases\\n\",\n    \"    projected_grad = model.opt.projected_grad\\n\",\n    \"    if isinstance(projected_grad, torch.Tensor):\\n\",\n    \"        projected_grad_is_nan = torch.isnan(projected_grad).any()\\n\",\n    \"    else:\\n\",\n    \"        projected_grad_is_nan = projected_grad != projected_grad  # Check for NaN in scalars\\n\",\n    \"    if torch.isnan(loss) or projected_grad_is_nan:\\n\",\n    \"        tqdm.write(\\\"'loss': {} or 'projected_grad': {} is nan. Drop this step.\\\".format(\\n\",\n    \"            loss, model.opt.projected_grad\\n\",\n    \"        ))\\n\",\n    \"        model.opt.projected_grad = 0  # Reset projected_grad to prevent parameter updates\\n\",\n    \"    return model, inputs, loss\\n\",\n    \"trainer.register_zo2_training_step_post_hook(drop_invalid_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# trainer step\\n\",\n    \"trainer.train()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"mezo\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n 
\"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "tutorial/nanogpt.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Environment Setting\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"sys.path.append(\\\"../\\\")\\n\",\n    \"\\n\",\n    \"import math\\n\",\n    \"import inspect\\n\",\n    \"from dataclasses import dataclass\\n\",\n    \"from tqdm.auto import tqdm\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import torch.nn as nn\\n\",\n    \"from torch.nn import functional as F\\n\",\n    \"\\n\",\n    \"from zo2 import ZOConfig\\n\",\n    \"from zo2.model.base import BaseZOModel\\n\",\n    \"from zo2.optimizer.mezo_sgd.zo import MeZOSGD\\n\",\n    \"from zo2.optimizer.mezo_sgd.zo2 import MeZO2SGD\\n\",\n    \"from zo2.config.mezo_sgd import MeZOSGDConfig\\n\",\n    \"from zo2.utils import seed_everything\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Define Your Model \\n\",\n    \"Here we use NanoGPT model, copied from [nanogpt github](https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py), as an example.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"@dataclass\\n\",\n    \"class GPTConfig:\\n\",\n    \"    block_size: int = 1024\\n\",\n    \"    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency\\n\",\n    \"    n_layer: int = 12\\n\",\n    \"    n_head: int = 12\\n\",\n    \"    n_embd: int = 768\\n\",\n    \"    dropout: float = 0.0\\n\",\n    \"    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\\n\",\n    \"\\n\",\n    \"class LayerNorm(nn.Module):\\n\",\n    \"    \\\"\\\"\\\" LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    def __init__(self, ndim, bias):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.weight = nn.Parameter(torch.ones(ndim))\\n\",\n    \"        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\\n\",\n    \"\\n\",\n    \"    def forward(self, input):\\n\",\n    \"        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\\n\",\n    \"\\n\",\n    \"class CausalSelfAttention(nn.Module):\\n\",\n    \"\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super().__init__()\\n\",\n    \"        assert config.n_embd % config.n_head == 0\\n\",\n    \"        # key, query, value projections for all heads, but in a batch\\n\",\n    \"        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)\\n\",\n    \"        # output projection\\n\",\n    \"        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\\n\",\n    \"        # regularization\\n\",\n    \"        self.attn_dropout = nn.Dropout(config.dropout)\\n\",\n    \"        self.resid_dropout = nn.Dropout(config.dropout)\\n\",\n    \"        self.n_head = config.n_head\\n\",\n    \"        self.n_embd = config.n_embd\\n\",\n    \"        self.dropout = config.dropout\\n\",\n    \"        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0\\n\",\n    \"        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')\\n\",\n    \"        if not self.flash:\\n\",\n    \"            print(\\\"WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0\\\")\\n\",\n    \"            # causal mask to ensure that attention is only applied to the left in the input sequence\\n\",\n    \"            self.register_buffer(\\\"bias\\\", torch.tril(torch.ones(config.block_size, config.block_size))\\n\",\n    \"                                        .view(1, 1, config.block_size, config.block_size))\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\\n\",\n    \"\\n\",\n    \"        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\\n\",\n    \"        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)\\n\",\n    \"        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\\n\",\n    \"        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\\n\",\n    \"        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\\n\",\n    \"\\n\",\n    \"        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)\\n\",\n    \"        if self.flash:\\n\",\n    \"            # efficient attention using Flash Attention CUDA kernels\\n\",\n    \"            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)\\n\",\n    \"        else:\\n\",\n    \"            # manual implementation of attention\\n\",\n    \"            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\\n\",\n    \"            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))\\n\",\n    \"            att = F.softmax(att, dim=-1)\\n\",\n    \"            att = self.attn_dropout(att)\\n\",\n    \"            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\\n\",\n    \"        y = y.transpose(1, 2).contiguous().view(B, T, C) # 
re-assemble all head outputs side by side\\n\",\n    \"\\n\",\n    \"        # output projection\\n\",\n    \"        y = self.resid_dropout(self.c_proj(y))\\n\",\n    \"        return y\\n\",\n    \"\\n\",\n    \"class MLP(nn.Module):\\n\",\n    \"\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)\\n\",\n    \"        self.gelu    = nn.GELU()\\n\",\n    \"        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)\\n\",\n    \"        self.dropout = nn.Dropout(config.dropout)\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        x = self.c_fc(x)\\n\",\n    \"        x = self.gelu(x)\\n\",\n    \"        x = self.c_proj(x)\\n\",\n    \"        x = self.dropout(x)\\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"class Block(nn.Module):\\n\",\n    \"\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\\n\",\n    \"        self.attn = CausalSelfAttention(config)\\n\",\n    \"        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\\n\",\n    \"        self.mlp = MLP(config)\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        x = x + self.attn(self.ln_1(x))\\n\",\n    \"        x = x + self.mlp(self.ln_2(x))\\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class GPT(nn.Module):\\n\",\n    \"\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super().__init__()\\n\",\n    \"        assert config.vocab_size is not None\\n\",\n    \"        assert config.block_size is not None\\n\",\n    \"        self.config = config\\n\",\n    \"\\n\",\n    \"        self.transformer = nn.ModuleDict(dict(\\n\",\n    \"            wte = nn.Embedding(config.vocab_size, config.n_embd),\\n\",\n    \"            wpe = 
nn.Embedding(config.block_size, config.n_embd),\\n\",\n    \"            drop = nn.Dropout(config.dropout),\\n\",\n    \"            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\\n\",\n    \"            ln_f = LayerNorm(config.n_embd, bias=config.bias),\\n\",\n    \"        ))\\n\",\n    \"        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\\n\",\n    \"        # with weight tying when using torch.compile() some warnings get generated:\\n\",\n    \"        # \\\"UserWarning: functional_call was passed multiple values for tied weights.\\n\",\n    \"        # This behavior is deprecated and will be an error in future versions\\\"\\n\",\n    \"        # not 100% sure what this is, so far seems to be harmless. TODO investigate\\n\",\n    \"        # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying\\n\",\n    \"\\n\",\n    \"        # init all weights\\n\",\n    \"        self.apply(self._init_weights)\\n\",\n    \"        # apply special scaled init to the residual projections, per GPT-2 paper\\n\",\n    \"        for pn, p in self.named_parameters():\\n\",\n    \"            if pn.endswith('c_proj.weight'):\\n\",\n    \"                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\\n\",\n    \"\\n\",\n    \"        # report number of parameters\\n\",\n    \"        print(\\\"number of parameters: %.2fM\\\" % (self.get_num_params()/1e6,))\\n\",\n    \"\\n\",\n    \"    def get_num_params(self, non_embedding=True):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Return the number of parameters in the model.\\n\",\n    \"        For non-embedding count (default), the position embeddings get subtracted.\\n\",\n    \"        The token embeddings would too, except due to the parameter sharing these\\n\",\n    \"        params are actually used as weights in the final layer, so we include them.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"  
      n_params = sum(p.numel() for p in self.parameters())\\n\",\n    \"        if non_embedding:\\n\",\n    \"            n_params -= self.transformer.wpe.weight.numel()\\n\",\n    \"        return n_params\\n\",\n    \"\\n\",\n    \"    def _init_weights(self, module):\\n\",\n    \"        if isinstance(module, nn.Linear):\\n\",\n    \"            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\\n\",\n    \"            if module.bias is not None:\\n\",\n    \"                torch.nn.init.zeros_(module.bias)\\n\",\n    \"        elif isinstance(module, nn.Embedding):\\n\",\n    \"            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\\n\",\n    \"\\n\",\n    \"    def forward(self, idx, pos, targets=None):\\n\",\n    \"        # idx is of shape (B, T)\\n\",\n    \"        B, T = idx.size()\\n\",\n    \"        assert T <= self.config.block_size, f\\\"Cannot forward sequence of length {T}, block size is only {self.config.block_size}\\\"\\n\",\n    \"        # forward the token and posisition embeddings\\n\",\n    \"        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)\\n\",\n    \"        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)\\n\",\n    \"        x = tok_emb + pos_emb\\n\",\n    \"        # forward the blocks of the transformer\\n\",\n    \"        for block in self.transformer.h:\\n\",\n    \"            x = block(x)\\n\",\n    \"        # forward the final layernorm and the classifier\\n\",\n    \"        x = self.transformer.ln_f(x)\\n\",\n    \"        logits = self.lm_head(x) # (B, T, vocab_size)\\n\",\n    \"        loss = None\\n\",\n    \"        if targets is not None:\\n\",\n    \"            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))\\n\",\n    \"        return logits, loss\\n\",\n    \"\\n\",\n    \"    def crop_block_size(self, block_size):\\n\",\n    \"        # model surgery to decrease the block size if 
necessary\\n\",\n    \"        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)\\n\",\n    \"        # but want to use a smaller block size for some smaller, simpler model\\n\",\n    \"        assert block_size <= self.config.block_size\\n\",\n    \"        self.config.block_size = block_size\\n\",\n    \"        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])\\n\",\n    \"        for block in self.transformer.h:\\n\",\n    \"            if hasattr(block.attn, 'bias'):\\n\",\n    \"                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]\\n\",\n    \"\\n\",\n    \"    @classmethod\\n\",\n    \"    def from_pretrained(cls, model_type, override_args=None):\\n\",\n    \"        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\\n\",\n    \"        override_args = override_args or {} # default to empty dict\\n\",\n    \"        # only dropout can be overridden see more notes below\\n\",\n    \"        assert all(k == 'dropout' for k in override_args)\\n\",\n    \"        from transformers import GPT2LMHeadModel\\n\",\n    \"        print(\\\"loading weights from pretrained gpt: %s\\\" % model_type)\\n\",\n    \"\\n\",\n    \"        # n_layer, n_head and n_embd are determined from model_type\\n\",\n    \"        config_args = {\\n\",\n    \"            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params\\n\",\n    \"            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\\n\",\n    \"            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\\n\",\n    \"            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\\n\",\n    \"        }[model_type]\\n\",\n    \"        print(\\\"forcing vocab_size=50257, block_size=1024, bias=True\\\")\\n\",\n    \"        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\\n\",\n    \"        
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints\\n\",\n    \"        config_args['bias'] = True # always True for GPT model checkpoints\\n\",\n    \"        # we can override the dropout rate, if desired\\n\",\n    \"        if 'dropout' in override_args:\\n\",\n    \"            print(f\\\"overriding dropout rate to {override_args['dropout']}\\\")\\n\",\n    \"            config_args['dropout'] = override_args['dropout']\\n\",\n    \"        # create a from-scratch initialized minGPT model\\n\",\n    \"        config = GPTConfig(**config_args)\\n\",\n    \"        model = GPT(config)\\n\",\n    \"        sd = model.state_dict()\\n\",\n    \"        sd_keys = sd.keys()\\n\",\n    \"        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\\n\",\n    \"\\n\",\n    \"        # init a huggingface/transformers model\\n\",\n    \"        model_hf = GPT2LMHeadModel.from_pretrained(model_type)\\n\",\n    \"        sd_hf = model_hf.state_dict()\\n\",\n    \"\\n\",\n    \"        # copy while ensuring all of the parameters are aligned and match in names and shapes\\n\",\n    \"        sd_keys_hf = sd_hf.keys()\\n\",\n    \"        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\\n\",\n    \"        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)\\n\",\n    \"        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\\n\",\n    \"        # basically the openai checkpoints use a \\\"Conv1D\\\" module, but we only want to use a vanilla Linear\\n\",\n    \"        # this means that we have to transpose these weights when we import them\\n\",\n    \"        assert len(sd_keys_hf) == len(sd_keys), f\\\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\\\"\\n\",\n    \"        for k in sd_keys_hf:\\n\",\n    \"            if any(k.endswith(w) 
for w in transposed):\\n\",\n    \"                # special treatment for the Conv1D weights we need to transpose\\n\",\n    \"                assert sd_hf[k].shape[::-1] == sd[k].shape\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    sd[k].copy_(sd_hf[k].t())\\n\",\n    \"            else:\\n\",\n    \"                # vanilla copy over the other parameters\\n\",\n    \"                assert sd_hf[k].shape == sd[k].shape\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    sd[k].copy_(sd_hf[k])\\n\",\n    \"\\n\",\n    \"        return model\\n\",\n    \"\\n\",\n    \"    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\\n\",\n    \"        # start with all of the candidate parameters\\n\",\n    \"        param_dict = {pn: p for pn, p in self.named_parameters()}\\n\",\n    \"        # filter out those that do not require grad\\n\",\n    \"        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\\n\",\n    \"        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\\n\",\n    \"        # i.e. 
all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\\n\",\n    \"        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\\n\",\n    \"        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\\n\",\n    \"        optim_groups = [\\n\",\n    \"            {'params': decay_params, 'weight_decay': weight_decay},\\n\",\n    \"            {'params': nodecay_params, 'weight_decay': 0.0}\\n\",\n    \"        ]\\n\",\n    \"        num_decay_params = sum(p.numel() for p in decay_params)\\n\",\n    \"        num_nodecay_params = sum(p.numel() for p in nodecay_params)\\n\",\n    \"        print(f\\\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\\\")\\n\",\n    \"        print(f\\\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\\\")\\n\",\n    \"        # Create AdamW optimizer and use the fused version if it is available\\n\",\n    \"        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\\n\",\n    \"        use_fused = fused_available and device_type == 'cuda'\\n\",\n    \"        extra_args = dict(fused=True) if use_fused else dict()\\n\",\n    \"        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\\n\",\n    \"        print(f\\\"using fused AdamW: {use_fused}\\\")\\n\",\n    \"\\n\",\n    \"        return optimizer\\n\",\n    \"\\n\",\n    \"    def estimate_mfu(self, fwdbwd_per_iter, dt):\\n\",\n    \"        \\\"\\\"\\\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \\\"\\\"\\\"\\n\",\n    \"        # first estimate the number of flops we do per iteration.\\n\",\n    \"        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\\n\",\n    \"        N = self.get_num_params()\\n\",\n    \"        cfg = self.config\\n\",\n    \"        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, 
cfg.block_size\\n\",\n    \"        flops_per_token = 6*N + 12*L*H*Q*T\\n\",\n    \"        flops_per_fwdbwd = flops_per_token * T\\n\",\n    \"        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\\n\",\n    \"        flops_achieved = flops_per_iter * (1.0/dt) # per second\\n\",\n    \"        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\\n\",\n    \"        mfu = flops_achieved / flops_promised\\n\",\n    \"        return mfu\\n\",\n    \"\\n\",\n    \"    @torch.no_grad()\\n\",\n    \"    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\\n\",\n    \"        the sequence max_new_tokens times, feeding the predictions back into the model each time.\\n\",\n    \"        Most likely you'll want to make sure to be in model.eval() mode of operation for this.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        for _ in range(max_new_tokens):\\n\",\n    \"            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\\n\",\n    \"            logits, _ = self(idx_cond)\\n\",\n    \"            logits = logits[:, -1, :] / temperature\\n\",\n    \"            if top_k is not None:\\n\",\n    \"                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\\n\",\n    \"                logits[logits < v[:, [-1]]] = -float('Inf')\\n\",\n    \"            probs = F.softmax(logits, dim=-1)\\n\",\n    \"            idx_next = torch.multinomial(probs, num_samples=1)\\n\",\n    \"            idx = torch.cat((idx, idx_next), dim=1)\\n\",\n    \"\\n\",\n    \"        return idx\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Apply ZO to NanoGPT\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# 
create a ZO optimizer\\n\",\n    \"class Optimizer(MeZOSGD):\\n\",\n    \"\\n\",\n    \"    @torch.inference_mode\\n\",\n    \"    def inner_zo_forward(self, idx, pos, targets):\\n\",\n    \"        tok_emb = self.model.transformer.wte(idx)\\n\",\n    \"        pos_emb = self.model.transformer.wpe(pos)\\n\",\n    \"        x = tok_emb + pos_emb\\n\",\n    \"        for block in self.model.transformer.h:\\n\",\n    \"            x = block(x)\\n\",\n    \"        x = self.model.transformer.ln_f(x)\\n\",\n    \"        x = self.model.lm_head(x)\\n\",\n    \"        loss = F.cross_entropy(\\n\",\n    \"            x.reshape(-1, x.size(-1)), \\n\",\n    \"            targets.reshape(-1)\\n\",\n    \"        )\\n\",\n    \"        return loss.detach()\\n\",\n    \"\\n\",\n    \"    @torch.inference_mode()   \\n\",\n    \"    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\\n\",\n    \"        output = eval_fn(idx, pos, targets)\\n\",\n    \"        return output\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# fused the ZO optimizer into model\\n\",\n    \"class ZOGPT(GPT, BaseZOModel):\\n\",\n    \"    def __init__(self, config: GPTConfig, zo_config: MeZOSGDConfig):\\n\",\n    \"        super().__init__(config)\\n\",\n    \"        self.opt = Optimizer(model=self, config=zo_config)\\n\",\n    \"\\n\",\n    \"    def forward(self, idx, pos, targets=None):\\n\",\n    \"        if self.zo_training:\\n\",\n    \"            return self.opt.zo_forward(idx, pos, targets)\\n\",\n    \"        else:\\n\",\n    \"            # for evaluate and inference purpose\\n\",\n    \"            return self.opt.zo_eval_forward(super().forward, idx, pos, targets)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Apply ZO2 to NanoGPT\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"# create a ZO2 optimizer\\n\",\n    \"class Optimizer(MeZO2SGD):\\n\",\n    \"    \\n\",\n    \"    def init_zo2_upload(self):\\n\",\n    \"        print(\\\"Upload head and tail to cuda.\\\")\\n\",\n    \"        self.model.transformer.wte = self.model.transformer.wte.to(self.device)\\n\",\n    \"        self.model.transformer.wpe = self.model.transformer.wpe.to(self.device)\\n\",\n    \"        self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device)\\n\",\n    \"        self.model.lm_head = self.model.lm_head.to(self.device)\\n\",\n    \"        \\n\",\n    \"        self.num_blocks = len(self.model.transformer.h)\\n\",\n    \"        if self.offloading_blocks is not None:\\n\",\n    \"            self.offloading_blocks = self.offloading_blocks\\n\",\n    \"        else:\\n\",\n    \"            self.offloading_blocks = list(range(self.num_blocks))\\n\",\n    \"        print(f\\\"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}\\\")\\n\",\n    \"        for i in range(self.num_blocks):\\n\",\n    \"            if i in self.offloading_blocks:\\n\",\n    \"                continue\\n\",\n    \"            else:\\n\",\n    \"                self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device)\\n\",\n    \"                print(f\\\"Upload block {i} to cuda.\\\")\\n\",\n    \"\\n\",\n    \"    @torch.inference_mode()   \\n\",\n    \"    def inner_zo_forward(self, idx, pos, targets):\\n\",\n    \"        we1, we2 = self.task_compute_module(self.model.transformer.wte,\\n\",\n    \"                                inputs1={\\\"input\\\": idx},\\n\",\n    \"                                inputs2={\\\"input\\\": idx},\\n\",\n    \"                                grad=self.projected_grad)\\n\",\n    \"            # only sync the compute stream at the first compute task\\n\",\n    \"        pe1, pe2 = 
self.task_compute_module(self.model.transformer.wpe, \\n\",\n    \"                                 {\\\"input\\\": pos}, \\n\",\n    \"                                 {\\\"input\\\": pos}, \\n\",\n    \"                                 self.projected_grad,\\n\",\n    \"                                 compute_sync=False)    \\n\",\n    \"            # disable the compute stream sync because we want all the compute tasks overlap with the following upload task\\n\",\n    \"        hidden_states1, hidden_states2 = self.task_compute_function(torch.add,\\n\",\n    \"                                                                    {\\\"input\\\": we1, \\\"other\\\": pe1},\\n\",\n    \"                                                                    {\\\"input\\\": we2, \\\"other\\\": pe2},\\n\",\n    \"                                                                    compute_sync=False)\\n\",\n    \"        if 0 in self.offloading_blocks:\\n\",\n    \"            self.model.transformer.h[0] = self.task_upload(\\n\",\n    \"                module=self.model.transformer.h[0], \\n\",\n    \"                device=self.device)\\n\",\n    \"        N = len(self.model.transformer.h)\\n\",\n    \"        for i in range(1, N):\\n\",\n    \"            # follow the rule that do offload the i-2-th block, compute the i-1-th block, and upload the i-th block in order.\\n\",\n    \"            if i != 1:\\n\",\n    \"                if i-2 in self.offloading_blocks:\\n\",\n    \"                    self.model.transformer.h[i-2] = self.task_offload(\\n\",\n    \"                        module=self.model.transformer.h[i-2], \\n\",\n    \"                        device=self.offloading_device)\\n\",\n    \"            hidden_states1, hidden_states2 = self.task_compute_module(\\n\",\n    \"                self.model.transformer.h[i-1], \\n\",\n    \"                inputs1={\\\"x\\\": hidden_states1}, \\n\",\n    \"                inputs2={\\\"x\\\": hidden_states2}, \\n\",\n    
\"                grad=self.projected_grad)\\n\",\n    \"            if i in self.offloading_blocks:\\n\",\n    \"                self.model.transformer.h[i] = self.task_upload(\\n\",\n    \"                    module=self.model.transformer.h[i], \\n\",\n    \"                    device=self.device)\\n\",\n    \"        if N-2 in self.offloading_blocks:\\n\",\n    \"            self.model.transformer.h[N-2] = self.task_offload(\\n\",\n    \"                self.model.transformer.h[N-2], device=self.offloading_device)\\n\",\n    \"        hidden_states1, hidden_states2 = self.task_compute_module(\\n\",\n    \"                    self.model.transformer.h[N-1], \\n\",\n    \"                    inputs1={\\\"x\\\": hidden_states1}, \\n\",\n    \"                    inputs2={\\\"x\\\": hidden_states2}, \\n\",\n    \"                    grad=self.projected_grad\\n\",\n    \"                )\\n\",\n    \"        if N-1 in self.offloading_blocks:\\n\",\n    \"            self.model.transformer.h[N-1] = self.task_offload(\\n\",\n    \"                self.model.transformer.h[N-1], device=self.offloading_device)\\n\",\n    \"        logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f,\\n\",\n    \"                                             inputs1={\\\"input\\\": hidden_states1}, \\n\",\n    \"                                             inputs2={\\\"input\\\": hidden_states2}, \\n\",\n    \"                                             grad=self.projected_grad,\\n\",\n    \"                                             weight_decay=0.)   
\\n\",\n    \"            # 'task_compute_module' will remove the first name 'ln_f', so we need to disable weight_decay manually.\\n\",\n    \"        logits1, logits2 = self.task_compute_module(self.model.lm_head,\\n\",\n    \"                                             inputs1={\\\"input\\\": logits1}, \\n\",\n    \"                                             inputs2={\\\"input\\\": logits2}, \\n\",\n    \"                                             grad=self.projected_grad)\\n\",\n    \"        loss1, loss2 = self.task_compute_function(F.cross_entropy,\\n\",\n    \"                                                  {\\\"input\\\": logits1.reshape(-1, logits1.size(-1)), \\n\",\n    \"                                                   \\\"target\\\": targets.reshape(-1)},\\n\",\n    \"                                                  {\\\"input\\\": logits2.reshape(-1, logits2.size(-1)), \\n\",\n    \"                                                   \\\"target\\\": targets.reshape(-1)})\\n\",\n    \"        return loss1, loss2\\n\",\n    \"    \\n\",\n    \"    @torch.inference_mode()   \\n\",\n    \"    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\\n\",\n    \"        handles = self.add_zo2_eval_comm_hooks(self.model.transformer.h)\\n\",\n    \"            # You can add zo2_eval_comm_hooks to all transformer blocks,\\n\",\n    \"            # but may be slower.\\n\",\n    \"        output = eval_fn(idx, pos, targets)\\n\",\n    \"        self.clear_zo2_eval_comm_hooks(handles)\\n\",\n    \"        return output\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# fused the ZO optimizer into model\\n\",\n    \"class ZO2GPT(GPT, BaseZOModel):\\n\",\n    \"    def __init__(self, config: GPTConfig, zo_config: MeZOSGDConfig):\\n\",\n    \"        super().__init__(config)\\n\",\n    \"        self.opt = Optimizer(model=self, config=zo_config)\\n\",\n    
\"\\n\",\n    \"    def forward(self, idx, pos, targets=None):\\n\",\n    \"        if self.zo_training:\\n\",\n    \"            return self.opt.zo_forward(idx, pos, targets)\\n\",\n    \"        else:\\n\",\n    \"            # for evaluate and inference purpose\\n\",\n    \"            return self.opt.zo_eval_forward(super().forward, idx, pos, targets)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Train the Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Hyperparameter\\n\",\n    \"zo_method = \\\"zo2\\\"\\n\",\n    \"eval_mode = False\\n\",\n    \"model_name = \\\"gpt2_xl\\\"\\n\",\n    \"verbose = True\\n\",\n    \"max_steps = 100\\n\",\n    \"learning_rate = 1e-4\\n\",\n    \"batch_size = 1\\n\",\n    \"weight_decay = 1e-1\\n\",\n    \"zo_eps = 1e-3\\n\",\n    \"seed = 42\\n\",\n    \"offloading_device = \\\"cpu\\\"\\n\",\n    \"working_device = \\\"cuda:0\\\"\\n\",\n    \"use_cache = True\\n\",\n    \"max_new_tokens = 50\\n\",\n    \"temperature = 1.0\\n\",\n    \"seed_everything(seed)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# ZO steps\\n\",\n    \"zo_config = ZOConfig(\\n\",\n    \"    method=\\\"mezo-sgd\\\", \\n\",\n    \"    zo2=zo_method==\\\"zo2\\\", \\n\",\n    \"    lr=learning_rate,\\n\",\n    \"    weight_decay=weight_decay,\\n\",\n    \"    eps=zo_eps,\\n\",\n    \"    offloading_device=offloading_device,\\n\",\n    \"    working_device=working_device,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Load ZO model\\n\",\n    \"class GPTConfigs:\\n\",\n    \"    gpt2: GPTConfig = GPTConfig(n_layer=12, n_head=12, n_embd=768)\\n\",\n    \"    gpt2_medium: GPTConfig = GPTConfig(n_layer=24, n_head=16, n_embd=1024)\\n\",\n    \"    gpt2_large: GPTConfig = GPTConfig(n_layer=36, n_head=20, 
n_embd=1280)\\n\",\n    \"    gpt2_xl: GPTConfig = GPTConfig(n_layer=48, n_head=25, n_embd=1600)\\n\",\n    \"cfgs = GPTConfigs()\\n\",\n    \"model_cfg = getattr(cfgs, model_name)\\n\",\n    \"MODEL_CLASS = ZO2GPT if zo_method==\\\"zo2\\\" else ZOGPT\\n\",\n    \"model = MODEL_CLASS(config=model_cfg, zo_config=zo_config)\\n\",\n    \"if zo_method != \\\"zo2\\\": \\n\",\n    \"    model = model.to(working_device)\\n\",\n    \"print(f\\\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# prepare some data (random generated)\\n\",\n    \"B, V, T = batch_size, model_cfg.vocab_size, model_cfg.block_size\\n\",\n    \"data_batch = torch.randint(0, V, (B, T+1)).to(working_device)\\n\",\n    \"input_ids = data_batch[:, :T]   # shift data and labels\\n\",\n    \"labels = data_batch[:, 1:T+1]\\n\",\n    \"pos = torch.arange(input_ids.shape[1], dtype=torch.long, device=working_device).unsqueeze(0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# training loop\\n\",\n    \"for i in tqdm(range(max_steps)):\\n\",\n    \"    model.zo_train()\\n\",\n    \"    loss = model(input_ids, pos, labels)\\n\",\n    \"    res = \\\"Iteration {}, loss: {}, projected grad: {}\\\"\\n\",\n    \"    tqdm.write(res.format(i, loss, model.opt.projected_grad))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"mezo\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n 
\"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "zo2/README.md",
    "content": "# Core code of ZO2\n\n## Features\n\n1. Fuse model dual-forward and optimizer step into model forward code. For example,\n\n```python\n# first-order one training step:\nmodel.train()\nloss = model(input, label)\t# forward\nloss.backward()\t\t# backward\noptimizer.step()\t# update parameters, optimizer states\n\n# zo2 one training step:\nmodel.zo_train()\t# Enable zo training\nloss = model(input, label)\t# fuse dual-forward, parameters and optimizer states updates\n```\n\n## Code Logic\n\n1. Fuse model dual-forward and optimizer step into model forward code.\n\n## In progress..."
  },
  {
    "path": "zo2/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n# configs\nfrom .config import ZOConfig\n\n# model\nfrom .model.nanogpt.mezo_sgd import get_nanogpt_mezo_sgd\n\nfrom .model.huggingface.zo_init import zo_hf_init\nfrom .model.huggingface.opt import (\n    get_opt_for_causalLM,\n    get_opt_for_sequence_classification,\n    get_opt_for_question_answering\n)"
  },
  {
    "path": "zo2/config/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .mezo_sgd import MeZOSGDConfig\n\n\ndef ZOConfig(method: str = \"mezo-sgd\", **kwargs):\n    match method:\n        case \"mezo-sgd\":\n            return MeZOSGDConfig(**kwargs)\n        # case \"another-method\":\n        #     return AnotherConfig(**kwargs)\n        case _:\n            raise ValueError(f\"Unsupported method {method}\")\n"
  },
  {
    "path": "zo2/config/mezo_sgd.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nfrom dataclasses import dataclass\n\n@dataclass\nclass MeZOSGDConfig:\n    # zo method\n    zo_method: str = \"mezo-sgd\" # zo method name, every zo config must include this attribute\n\n    # zo config\n    lr: float = 1e-3\n    weight_decay: float = 1e-1\n    eps: float = 1e-3\n    max_zo_random_seed = 1000000000\n\n    # zo2 config\n    zo2: bool = True    # use offloading or not\n    offloading_blocks: list = None  # specify offloading blocks or not\n    offloading_device: str = 'cpu'  # offload device, can be CPU or a path (for disk offloading, but currently unavailable)\n    working_device: str = 'cuda'    # compute device, can be any CUDA device\n    overlap: bool = True    # use scheduler to overlap or not\n    compute_module_optimize_method: str = ''   # possible values are: ['', 'torch.compile']\n    compute_function_optimize_method: str = ''   # possible values are: ['', 'torch.jit.script']\n    communicate_optimize_method: str = ''   # possible values are: ['', 'bucket']\n    amp: bool = False   # use amp or not\n    amp_precision: torch.dtype = torch.bfloat16 # amp autocast precision, possible values are: [torch.bfloat16, torch.float32], valid when using amp\n    precision_on_offloading_device: torch.dtype = torch.float16 # precision on offloading device, valid when using amp\n    precision_on_working_device: torch.dtype = torch.float32    # precision on working device, valid when using amp\n    amp_compress_method: str = 'naive'  # currently only support naive amp compress, valid when using amp\n\n    # debug\n    debug_mode: bool = False    # set 'True' to disable random noise"
  },
  {
    "path": "zo2/model/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/model/base.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\nclass BaseZOModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.zo_training = True\n        self.zo_train_loss_fn_pre_hooks = []\n        self.zo_train_loss_fn_post_hooks = []\n        self.zo_eval_loss_fn_pre_hooks = []\n        self.zo_eval_loss_fn_post_hooks = []\n        self.zo_custom_train_loss_fn = None\n        self.zo_custom_eval_loss_fn = None\n\n    def zo_train(self):\n        \"\"\"\n            Zeroth-order training\n        \"\"\"\n        self.zo_training = True\n        self.eval()\n\n    def zo_eval(self):\n        \"\"\"\n            Zeroth-order evaluation\n        \"\"\"\n        self.zo_training = False\n        self.eval()\n\n    def register_zo_train_loss_fn_pre_hook(self, hook_fn):\n        self.zo_train_loss_fn_pre_hooks.append(hook_fn)\n\n    def register_zo_train_loss_fn_post_hook(self, hook_fn):\n        self.zo_train_loss_fn_post_hooks.append(hook_fn)\n\n    def register_zo_eval_loss_fn_pre_hook(self, hook_fn):\n        self.zo_eval_loss_fn_pre_hooks.append(hook_fn)\n\n    def register_zo_eval_loss_fn_post_hook(self, hook_fn):\n        self.zo_eval_loss_fn_post_hooks.append(hook_fn)\n\n    def register_custom_opt(self, custom_opt_obj):\n        if hasattr(self, \"opt\"):\n            self.opt = custom_opt_obj\n        for module in self.children():\n            if hasattr(module, \"opt\"):\n                module.opt = custom_opt_obj"
  },
  {
    "path": "zo2/model/huggingface/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/model/huggingface/gpt/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/model/huggingface/gpt/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/model/huggingface/llama/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.models.llama import modeling_llama"
  },
  {
    "path": "zo2/model/huggingface/llama/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.models.llama import modeling_llama"
  },
  {
    "path": "zo2/model/huggingface/opt/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n    mezo_sgd,\n)\n\ndef get_opt_for_causalLM(zo_config):\n    zo2_supported_configs = {\n        \"mezo-sgd\": mezo_sgd.get_opt_for_causalLM_mezo_sgd,\n    }\n    return zo2_supported_configs[zo_config.zo_method](zo_config)\n\ndef get_opt_for_sequence_classification(zo_config):\n    zo2_supported_configs = {\n        \"mezo-sgd\": mezo_sgd.get_opt_for_sequence_classification_mezo_sgd,\n    }\n    return zo2_supported_configs[zo_config.zo_method](zo_config)\n\ndef get_opt_for_question_answering(zo_config):\n    zo2_supported_configs = {\n        \"mezo-sgd\": mezo_sgd.get_opt_for_question_answering_mezo_sgd,\n    }\n    return zo2_supported_configs[zo_config.zo_method](zo_config)\n"
  },
  {
    "path": "zo2/model/huggingface/opt/mezo_sgd/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import zo, zo2\nfrom .....config.mezo_sgd import MeZOSGDConfig\n\ndef get_opt_for_causalLM_mezo_sgd(config: MeZOSGDConfig):\n    return zo2.OPTForCausalLM if config.zo2 else zo.OPTForCausalLM\n\ndef get_opt_for_sequence_classification_mezo_sgd(config: MeZOSGDConfig):\n    return zo2.OPTForSequenceClassification if config.zo2 else zo.OPTForSequenceClassification\n\ndef get_opt_for_question_answering_mezo_sgd(config: MeZOSGDConfig):\n    return zo2.OPTForQuestionAnswering if config.zo2 else zo.OPTForQuestionAnswering\n"
  },
  {
    "path": "zo2/model/huggingface/opt/mezo_sgd/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\ndef fn_get_opt_decoder_hidden_states_from_layer_outputs(input):\n    return input[0]\n\ndef get_shift_logits(logits):\n    return logits[..., :-1, :].contiguous()\n\ndef get_shift_labels(labels):\n    return labels[..., 1:].contiguous()\n\ndef get_pooled_logits(logits, batch_size, sequence_lengths):\n    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\ndef get_start_logits_and_end_logits(logits):\n    start_logits, end_logits = logits.split(1, dim=-1)\n    start_logits = start_logits.squeeze(-1).contiguous()\n    end_logits = end_logits.squeeze(-1).contiguous()\n    return start_logits, end_logits\n\ndef get_qa_loss(loss_fct, start_logits, start_positions, end_logits, end_positions):\n    start_loss = loss_fct(start_logits, start_positions)\n    end_loss = loss_fct(end_logits, end_positions)\n    total_loss = (start_loss + end_loss) / 2\n    return total_loss\n\ndef init_all_hidden_states(output_hidden_states):\n    return () if output_hidden_states else None\n\ndef init_all_self_attns(output_attentions):\n    return () if output_attentions else None\n\ndef init_next_decoder_cache(use_cache):\n    return () if use_cache else None\n\ndef update_next_decoder_cache(use_cache, next_decoder_cache, layer_outputs, output_attentions):\n    if use_cache:\n        next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n    return next_decoder_cache\n\ndef update_all_self_attns(output_attentions, all_self_attns, layer_outputs):\n    if output_attentions:\n        all_self_attns += (layer_outputs[1],)\n    return all_self_attns\n\ndef update_all_hidden_states(output_hidden_states, all_hidden_states, hidden_states):\n    if output_hidden_states:\n        all_hidden_states += (hidden_states,)\n    return all_hidden_states\n\ndef get_past_key_value(past_key_values, idx):\n    return past_key_values[idx] if 
past_key_values is not None else None\n\ndef get_opt_sequence_classification_pooled_logits(self, logits, input_ids, inputs_embeds):\n    if input_ids is not None:\n        batch_size, sequence_length = input_ids.shape[:2]\n    else:\n        batch_size, sequence_length = inputs_embeds.shape[:2]\n    if self.config.pad_token_id is None:\n        sequence_lengths = -1\n    else:\n        if input_ids is not None:\n            sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)\n        else:\n            sequence_lengths = -1\n    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\ndef get_opt_sequence_classification_loss(self, loss, pooled_logits, labels):\n    if self.config.problem_type is None:\n        if self.num_labels == 1:\n            self.config.problem_type = \"regression\"\n        elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n            self.config.problem_type = \"single_label_classification\"\n        else:\n            self.config.problem_type = \"multi_label_classification\"\n    if self.config.problem_type == \"regression\":\n        loss_fct = torch.nn.MSELoss()\n        if self.num_labels == 1:\n            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n        else:\n            loss = loss_fct(pooled_logits, labels)\n    elif self.config.problem_type == \"single_label_classification\":\n        loss_fct = torch.nn.CrossEntropyLoss()\n        loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n    elif self.config.problem_type == \"multi_label_classification\":\n        loss_fct = torch.nn.BCEWithLogitsLoss()\n        loss = loss_fct(pooled_logits, labels)\n    return loss\n\ndef get_opt_question_answering_start_end_logits(logits):\n    start_logits, end_logits = logits.split(1, dim=-1)\n    start_logits = start_logits.squeeze(-1).contiguous()\n    end_logits = 
end_logits.squeeze(-1).contiguous()\n    return start_logits, end_logits\n\ndef get_opt_question_answering_loss(total_loss, start_logits, start_positions, end_logits, end_positions):\n    # If we are on multi-GPU, split add a dimension\n    if len(start_positions.size()) > 1:\n        start_positions = start_positions.squeeze(-1)\n    if len(end_positions.size()) > 1:\n        end_positions = end_positions.squeeze(-1)\n    # sometimes the start/end positions are outside our model inputs, we ignore these terms\n    ignored_index = start_logits.size(1)\n    start_positions = start_positions.clamp(0, ignored_index)\n    end_positions = end_positions.clamp(0, ignored_index)\n\n    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)\n    start_loss = loss_fct(start_logits, start_positions)\n    end_loss = loss_fct(end_logits, end_positions)\n    total_loss = (start_loss + end_loss) / 2\n    return total_loss"
  },
  {
    "path": "zo2/model/huggingface/opt/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.models.opt import modeling_opt\nfrom transformers.models.opt.modeling_opt import (\n    OPTConfig,\n    OPTPreTrainedModel,\n    OPTLearnedPositionalEmbedding,\n    OPTDecoderLayer,\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    QuestionAnsweringModelOutput,\n    add_start_docstrings_to_model_forward,\n    add_code_sample_docstrings,\n    replace_return_docstrings,\n    OPT_INPUTS_DOCSTRING,\n    _CHECKPOINT_FOR_DOC,\n    _CONFIG_FOR_DOC,\n    _EXPECTED_OUTPUT_SHAPE,\n    _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,\n    _SEQ_CLASS_EXPECTED_OUTPUT,\n    _SEQ_CLASS_EXPECTED_LOSS,\n)\nfrom transformers.utils import logging\n\nimport random\nfrom typing import List, Optional, Tuple, Union\n\nfrom ....base import BaseZOModel\nfrom .....optimizer.mezo_sgd.zo import MeZOSGD\nfrom .....config.mezo_sgd import MeZOSGDConfig\n\nlogger = logging.get_logger(__name__)\n\n\n\nclass OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]\n    \n    Args:\n        config: OPTConfig\n    \"\"\"\n\n    def __init__(self, config: OPTConfig):\n        \"\"\"\n        !!! 
Module register must follow the execution order.\n        \"\"\"\n        OPTPreTrainedModel.__init__(self, config)\n        self.dropout = config.dropout\n        self.layerdrop = config.layerdrop\n        self.padding_idx = config.pad_token_id\n        self.max_target_positions = config.max_position_embeddings\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)\n        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)\n        else:\n            self.project_in = None\n\n        self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])\n\n        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility\n        # with checkpoints that have been fine-tuned before transformers v4.20.1\n        # see https://github.com/facebookresearch/metaseq/pull/164\n        if config.do_layer_norm_before and not config._remove_final_layer_norm:\n            self.final_layer_norm = nn.LayerNorm(\n                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine\n            )\n        else:\n            self.final_layer_norm = None\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)\n        else:\n            self.project_out = None\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n\nclass OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        
self.decoder = OPTDecoder(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n\nclass OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, BaseZOModel):\n    _keys_to_ignore_on_load_missing = [r\"lm_head.weight\"]\n\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = OPTModel(config)\n\n        # the lm_head weight is automatically tied to the embed tokens weight\n        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.opt = OptimizerOPTForCausalLM(model=self, config=zo_config)\n\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):\n                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
The two additional\n                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, OPTForCausalLM\n\n        >>> model = OPTForCausalLM.from_pretrained(\"facebook/opt-350m\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\")\n\n        >>> prompt = \"Hey, are you consciours? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you consciours? 
Can you talk to me?\\nI'm not consciours, but I can talk to you.\"\n        ```\"\"\"\n\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, \n                position_ids, cache_position, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(super().forward, \n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, \n                position_ids, cache_position, **kwargs)\n\n\nclass OPTForSequenceClassification(modeling_opt.OPTForSequenceClassification, OPTPreTrainedModel, BaseZOModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.num_labels = config.num_labels\n        self.model = OPTModel(config)\n        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.opt = OptimizerOPTForSequenceClassification(model=self, config=zo_config)\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,\n        output_type=SequenceClassifierOutputWithPast,\n        config_class=_CONFIG_FOR_DOC,\n        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,\n        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: 
Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n        else:\n            return self.opt.zo_eval_forward(super().forward, \n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n\n\nclass OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTPreTrainedModel, BaseZOModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = OPTModel(config)\n        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n    \n    def zo_init(self, zo_config):\n        self.opt = OptimizerOPTForQuestionAnswering(model=self, config=zo_config)\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n   
 @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering\n        >>> import torch\n\n        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\")\n\n        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,\n        >>> # so the head will be randomly initialized, hence the predictions will be random\n        >>> model = OPTForQuestionAnswering.from_pretrained(\"facebook/opt-350m\")\n\n        >>> question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n\n        >>> inputs = tokenizer(question, text, return_tensors=\"pt\")\n        >>> with torch.no_grad():\n        ...     outputs = model(**inputs)\n\n        >>> answer_start_index = outputs.start_logits.argmax()\n        >>> answer_end_index = outputs.end_logits.argmax()\n\n        >>> answer_offset = len(tokenizer(question)[0])\n\n        >>> predict_answer_tokens = inputs.input_ids[\n        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1\n        ... 
]\n        >>> predicted = tokenizer.decode(predict_answer_tokens)\n        >>> predicted\n        ' a nice puppet'\n        ```\"\"\"\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, start_positions, end_positions, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n        else:\n            return self.opt.zo_eval_forward(super().forward, \n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, start_positions, end_positions, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n\n\nclass OptimizerOPTForCausalLM(MeZOSGD):\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is 
not None else self.model.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            position_ids=position_ids,\n            cache_position=cache_position,\n        )\n\n        logits = self.model.lm_head(outputs[0]).contiguous()\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss = self.model.zo_custom_train_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n        elif labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(shift_logits.view(-1, self.model.config.vocab_size), shift_labels.view(-1))\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                loss, input_ids, logits, labels = post_hook_fn(self.model, loss, input_ids, logits, labels)\n\n        # add --> only return loss\n        return loss.detach()\n\n    @torch.inference_mode()   \n    def inner_zo_eval_forward(\n        self,\n        eval_fn,\n        input_ids: torch.LongTensor = None,\n        attention_mask: 
Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        if self.model.zo_custom_eval_loss_fn:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, None, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n            if not return_dict:\n                logits = output[0]\n                loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n                output = (logits,) + output[1]\n                return (loss,) + output if loss is not None else output\n            logits = output[\"logits\"]\n            loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n            output = CausalLMOutputWithPast(\n                loss=loss,\n                logits=logits,\n                past_key_values=output[\"past_key_values\"],\n                hidden_states=output[\"hidden_states\"],\n                attentions=output[\"attentions\"],\n            )\n        else:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, 
\n                output_attentions, output_hidden_states, return_dict,\n                position_ids, cache_position)\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                output, input_ids, logits, labels = post_hook_fn(self.model, output, input_ids, logits, labels)\n        return output\n    \n\nclass OptimizerOPTForSequenceClassification(MeZOSGD):\n\n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        transformer_outputs = self.model.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.model.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size, sequence_length = 
input_ids.shape[:2]\n        else:\n            batch_size, sequence_length = inputs_embeds.shape[:2]\n\n        if self.model.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (torch.ne(input_ids, self.model.config.pad_token_id).sum(-1) - 1).to(logits.device)\n            else:\n                sequence_lengths = -1\n                logger.warning(\n                    f\"{self.model.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be \"\n                    \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n                )\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss = self.model.zo_custom_train_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n        elif labels is not None:\n            if self.model.config.problem_type is None:\n                if self.model.num_labels == 1:\n                    self.model.config.problem_type = \"regression\"\n                elif self.model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.model.config.problem_type = \"single_label_classification\"\n                else:\n                    self.model.config.problem_type = \"multi_label_classification\"\n\n            if self.model.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.model.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = 
loss_fct(pooled_logits, labels)\n            elif self.model.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.model.num_labels), labels.view(-1))\n            elif self.model.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        \n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                loss, input_ids, logits, labels = post_hook_fn(self.model, loss, input_ids, logits, labels)\n\n        # add --> only return loss\n        if self.model.zo_training:\n            return loss.detach()\n        \n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        eval_fn,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        if self.model.zo_custom_eval_loss_fn:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, None, use_cache, \n                output_attentions, output_hidden_states, 
return_dict)\n            if not return_dict:\n                logits = output[0]\n                loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n                output = (logits,) + output[1]\n                return (loss,) + output if loss is not None else output\n            logits = output[\"logits\"]\n            loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n            output = CausalLMOutputWithPast(\n                loss=loss,\n                logits=logits,\n                past_key_values=output[\"past_key_values\"],\n                hidden_states=output[\"hidden_states\"],\n                attentions=output[\"attentions\"],\n            )\n        else:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                output, input_ids, logits, labels = post_hook_fn(self.model, output, input_ids, logits, labels)\n        return output\n\n\nclass OptimizerOPTForQuestionAnswering(MeZOSGD):\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        
**kwargs\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n        \n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        transformer_outputs = self.model.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        logits = self.model.qa_outputs(hidden_states)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                input_ids, start_logits, start_positions, end_logits, end_positions = \\\n                    pre_hook_fn(self.model, input_ids, start_logits, start_positions, end_logits, end_positions)\n\n        total_loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss = self.model.zo_custom_train_loss_fn(self.model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)\n        elif start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we 
ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions = start_positions.clamp(0, ignored_index)\n            end_positions = end_positions.clamp(0, ignored_index)\n\n            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                loss, input_ids, start_logits, start_positions, end_logits, end_positions = \\\n                    post_hook_fn(self.model, loss, input_ids, start_logits, start_positions, end_logits, end_positions)\n\n        # add --> only return loss\n        if self.model.zo_training:\n            return total_loss.detach()\n        \n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        eval_fn,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, start_logits, start_positions, end_logits, end_positions = pre_hook_fn(self.model, input_ids, start_logits, start_positions, 
end_logits, end_positions)\n\n        if self.model.zo_custom_eval_loss_fn:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, None, None, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n            if not return_dict:\n                start_logits, end_logits = output[0], output[1]\n                loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)\n                output = (start_logits, end_logits) + output[2:]\n                return (loss,) + output if loss is not None else output\n            start_logits = output[\"start_logits\"]\n            end_logits = output[\"end_logits\"]\n            loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, start_logits, start_positions, end_logits, end_positions, **kwargs)\n            output = QuestionAnsweringModelOutput(\n                loss=loss,\n                start_logits=start_logits,\n                end_logits=end_logits,\n                hidden_states=output[\"hidden_states\"],\n                attentions=output[\"attentions\"],\n            )\n        else:\n            output = eval_fn(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, start_positions, end_positions, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                output, input_ids, start_logits, start_positions, end_logits, end_positions = post_hook_fn(self.model, output, input_ids, start_logits, start_positions, end_logits, end_positions)\n        return output"
  },
  {
    "path": "zo2/model/huggingface/opt/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport random\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.models.opt import modeling_opt\nfrom transformers.models.opt.modeling_opt import (\n    OPTConfig,\n    OPTPreTrainedModel,\n    OPTLearnedPositionalEmbedding,\n    OPTDecoderLayer,\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    QuestionAnsweringModelOutput,\n    add_start_docstrings_to_model_forward,\n    add_code_sample_docstrings,\n    replace_return_docstrings,\n    OPT_INPUTS_DOCSTRING,\n    _CHECKPOINT_FOR_DOC,\n    _CONFIG_FOR_DOC,\n    _EXPECTED_OUTPUT_SHAPE,\n    _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,\n    _SEQ_CLASS_EXPECTED_OUTPUT,\n    _SEQ_CLASS_EXPECTED_LOSS,\n)\nfrom transformers.utils import logging\n\nfrom typing import List, Optional, Tuple, Union\n\nfrom ....base import BaseZOModel\nfrom .....optimizer.mezo_sgd.zo2 import MeZO2SGD\nfrom .....config.mezo_sgd import MeZOSGDConfig\nfrom .utils import *\n\nlogger = logging.get_logger(__name__)\n\n\nclass OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel, BaseZOModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]\n    \n    Args:\n        config: OPTConfig\n    \"\"\"\n\n    def __init__(self, config: OPTConfig):\n        \"\"\"\n        !!! 
Module register must follow the execution order.\n        \"\"\"\n        OPTPreTrainedModel.__init__(self, config)\n        self.dropout = config.dropout\n        self.layerdrop = config.layerdrop\n        self.padding_idx = config.pad_token_id\n        self.max_target_positions = config.max_position_embeddings\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)\n        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)\n        else:\n            self.project_in = None\n\n        self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])\n\n        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility\n        # with checkpoints that have been fine-tuned before transformers v4.20.1\n        # see https://github.com/facebookresearch/metaseq/pull/164\n        if config.do_layer_norm_before and not config._remove_final_layer_norm:\n            self.final_layer_norm = nn.LayerNorm(\n                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine\n            )\n        else:\n            self.final_layer_norm = None\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)\n        else:\n            self.project_out = None\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n    \n    def zo_init(self, zo_config):\n        # Initialize ZO2\n        self.opt = OptimizerOPTDecoder(model=self, config=zo_config)\n    \n    def forward(\n        
self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        if self.zo_training:\n            return self.opt.inner_zo_forward(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, return_dict,\n                position_ids, cache_position)\n        else:\n            return self.opt.zo_eval_forward(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, return_dict,\n                position_ids, cache_position)\n\n\nclass OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel, BaseZOModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.decoder = OPTDecoder(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n    \n    def zo_init(self, zo_config):\n        self.decoder.zo_init(zo_config)\n        # Initialize ZO2\n        self.opt = OptimizerOPTModel(model=self, config=zo_config)\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutputWithPast,\n        config_class=_CONFIG_FOR_DOC,\n        
expected_output=_EXPECTED_OUTPUT_SHAPE,\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        if self.zo_training:\n            return self.opt.inner_zo_forward(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n        else:\n            return self.opt.zo_eval_forward(input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, return_dict)\n\n\nclass OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, BaseZOModel):\n    _keys_to_ignore_on_load_missing = [r\"lm_head.weight\"]\n\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = OPTModel(config)\n        # the lm_head weight is automatically tied to the embed tokens weight\n        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)\n        # Initialize weights and apply final processing\n        self.post_init()\n    \n    def zo_init(self, zo_config):\n        self.model.zo_init(zo_config)\n        # Initialize ZO2\n        self.opt = OptimizerOPTForCausalLM(model=self, config=zo_config)\n\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: 
torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):\n                Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional\n                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, OPTForCausalLM\n\n        >>> model = OPTForCausalLM.from_pretrained(\"facebook/opt-350m\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\")\n\n        >>> prompt = \"Hey, are you consciours? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you consciours? 
Can you talk to me?\\nI'm not consciours, but I can talk to you.\"\n        ```\"\"\"\n\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n\n\nclass OPTForSequenceClassification(modeling_opt.OPTForSequenceClassification, OPTPreTrainedModel, BaseZOModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.num_labels = config.num_labels\n        self.model = OPTModel(config)\n        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.model.zo_init(zo_config)\n        self.opt = OptimizerOPTForSequenceClassification(model=self, config=zo_config)\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,\n        output_type=SequenceClassifierOutputWithPast,\n        config_class=_CONFIG_FOR_DOC,\n        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,\n        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = 
None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n\n\nclass OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTPreTrainedModel, BaseZOModel):\n    def __init__(self, config: OPTConfig):\n        OPTPreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = OPTModel(config)\n        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.model.zo_init(zo_config)\n        self.opt = OptimizerOPTForQuestionAnswering(model=self, config=zo_config)\n\n    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)\n    
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering\n        >>> import torch\n\n        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\")\n\n        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,\n        >>> # so the head will be randomly initialized, hence the predictions will be random\n        >>> model = OPTForQuestionAnswering.from_pretrained(\"facebook/opt-350m\")\n\n        >>> question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n\n        >>> inputs = tokenizer(question, text, return_tensors=\"pt\")\n        >>> with torch.no_grad():\n        ...     outputs = model(**inputs)\n\n        >>> answer_start_index = outputs.start_logits.argmax()\n        >>> answer_end_index = outputs.end_logits.argmax()\n\n        >>> answer_offset = len(tokenizer(question)[0])\n\n        >>> predict_answer_tokens = inputs.input_ids[\n        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1\n        ... 
]\n        >>> predicted = tokenizer.decode(predict_answer_tokens)\n        >>> predicted\n        ' a nice puppet'\n        ```\"\"\"\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, start_positions, end_positions, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(\n                input_ids, attention_mask, head_mask, \n                past_key_values, inputs_embeds, start_positions, end_positions, use_cache, \n                output_attentions, output_hidden_states, return_dict, **kwargs)\n\n\nclass OptimizerOPTDecoder(MeZO2SGD):\n\n    def init_zo2(self):\n        self.upload_stream = None\n        self.offload_stream = None\n        self.compute_stream = None\n        self.zo_random_seed = None\n        self.rstate = None\n        self.rstate_queue = None\n        self.last_rstate = None\n        self.projected_grad = None\n        self.init_zo2_upload()\n    \n    def init_zo2_upload(self):\n        self.model.embed_tokens = self.model.embed_tokens.to(self.device)\n        self.model.embed_positions = self.model.embed_positions.to(self.device)\n        if self.model.project_out:\n            self.model.project_out = self.model.project_out.to(self.device)\n        if self.model.project_in:\n            self.model.project_in = self.model.project_in.to(self.device)\n        if self.model.final_layer_norm:\n            self.model.final_layer_norm = self.model.final_layer_norm.to(self.device)\n        self.num_blocks = len(self.model.layers)\n        if self.offloading_blocks is not None:\n            self.offloading_blocks = self.offloading_blocks\n        else:\n            self.offloading_blocks = list(range(self.num_blocks))\n        print(f\"Transformer blocks {self.offloading_blocks} will be offloaded to 
{self.offloading_device}\")\n        for i in range(self.num_blocks):\n            if i in self.offloading_blocks:\n                continue\n            else:\n                self.model.layers[i] = self.model.layers[i].to(self.device)\n                print(f\"Upload block {i} to {self.device}.\")\n        \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise 
ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n        if inputs_embeds is None:\n            # inputs_embeds = self.model.embed_tokens(input_ids)\n            inputs_embeds1, inputs_embeds2 = self.task_compute_module(self.model.embed_tokens,\n                                                                      inputs1={\"input\": input_ids},\n                                                                      inputs2={\"input\": input_ids},\n                                                                      grad=self.projected_grad)\n        else:\n            inputs_embeds1 = inputs_embeds2 = inputs_embeds\n\n        batch_size, seq_length = input_shape\n        # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        past_key_values_length = 0\n        # required mask seq length can be calculated via length of past\n        # mask_seq_length = past_key_values_length + seq_length\n        mask_seq_length = seq_length\n\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        if cache_position is None:\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds1.shape[1], device=inputs_embeds1.device\n            )\n\n        # embed positions\n        if attention_mask is None:\n            seq_length = past_seen_tokens + inputs_embeds1.shape[1]\n            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds1.device)\n        # causal_mask = self.model._update_causal_mask(\n        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        # )\n        causal_attention_mask1, causal_attention_mask2 = self.task_compute_function(\n            self.model._update_causal_mask,\n            inputs1={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds1, \"cache_position\": cache_position,\n      
               \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            inputs2={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds2, \"cache_position\": cache_position,\n                     \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            compute_sync=False\n        )\n        # pos_embeds = self.model.embed_positions(attention_mask, past_key_values_length)\n        pos_embeds1, pos_embeds2 = self.task_compute_module(self.model.embed_positions,\n                                                            inputs1={\"attention_mask\": attention_mask, \"past_key_values_length\": past_key_values_length},\n                                                            inputs2={\"attention_mask\": attention_mask, \"past_key_values_length\": past_key_values_length},\n                                                            grad=self.projected_grad,\n                                                            compute_sync=False)\n\n        if self.model.project_in is not None:\n            # inputs_embeds = self.model.project_in(inputs_embeds)\n            inputs_embeds1, inputs_embeds2 = self.task_compute_module(self.model.project_in,\n                                                                      inputs1={\"input\": inputs_embeds1},\n                                                                      inputs2={\"input\": inputs_embeds2},\n                                                                      grad=self.projected_grad,\n                                                                      compute_sync=False)\n\n        # hidden_states = inputs_embeds + pos_embeds\n        hidden_states1, hidden_states2 = self.task_compute_function(torch.add,\n                                                                    inputs1={\"input\": inputs_embeds1, \"other\": pos_embeds1},\n                                                                    
inputs2={\"input\": inputs_embeds2, \"other\": pos_embeds2},\n                                                                    compute_sync=False)\n\n        if 0 in self.offloading_blocks:\n            self.model.layers[0] = self.task_upload(\n                module=self.model.layers[0],\n                device=self.device\n            )\n\n        # if self.model.gradient_checkpointing and self.model.training:\n        #     if use_cache:\n        #         logger.warning_once(\n        #             \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n        #         )\n        #         use_cache = False\n\n        # # decoder layers\n        # all_hidden_states = () if output_hidden_states else None\n        # all_self_attns = () if output_attentions else None\n        # next_decoder_cache = () if use_cache else None\n\n        # check if head_mask has a correct number of layers specified if desired\n        for attn_mask, mask_name in zip([head_mask], [\"head_mask\"]):\n            if attn_mask is not None:\n                if attn_mask.size()[0] != (len(self.model.layers)):\n                    raise ValueError(\n                        f\"The `{mask_name}` should be specified for {len(self.model.layers)} layers, but it is for\"\n                        f\" {head_mask.size()[0]}.\"\n                    )\n\n        N = len(self.model.layers)\n        for i in range(1, N):\n\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.layers[i-2] = self.task_offload(\n                        module=self.model.layers[i-2],\n                        device=self.offloading_device)\n                    \n            layer_outputs1, layer_outputs2 = self.task_compute_module(\n                self.model.layers[i-1],\n                inputs1={\"hidden_states\": hidden_states1, \"attention_mask\": causal_attention_mask1, \n                            \"layer_head_mask\": 
(head_mask[i-1] if head_mask is not None else None),\n                            \"output_attentions\": output_attentions},\n                inputs2={\"hidden_states\": hidden_states2, \"attention_mask\": causal_attention_mask2, \n                            \"layer_head_mask\": (head_mask[i-1] if head_mask is not None else None),\n                            \"output_attentions\": output_attentions},\n                grad=self.projected_grad)\n\n            # hidden_states = layer_outputs[0]\n            hidden_states1, hidden_states2 = self.task_compute_function(\n                fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n                inputs1={\"input\": layer_outputs1},\n                inputs2={\"input\": layer_outputs2},\n                compute_sync=False\n            )\n            \n            if i in self.offloading_blocks:\n                self.model.layers[i] = self.task_upload(\n                    module=self.model.layers[i],\n                    device=self.device)\n\n        if N-2 in self.offloading_blocks:\n            self.model.layers[N-2] = self.task_offload(\n                module=self.model.layers[N-2],\n                device=self.offloading_device)\n        \n        layer_outputs1, layer_outputs2 = self.task_compute_module(\n            self.model.layers[N-1],\n            inputs1={\"hidden_states\": hidden_states1, \"attention_mask\": causal_attention_mask1, \n                        \"layer_head_mask\": (head_mask[i-1] if head_mask is not None else None),\n                        \"output_attentions\": output_attentions},\n            inputs2={\"hidden_states\": hidden_states2, \"attention_mask\": causal_attention_mask2, \n                        \"layer_head_mask\": (head_mask[i-1] if head_mask is not None else None),\n                        \"output_attentions\": output_attentions},\n            grad=self.projected_grad)\n\n        hidden_states1, hidden_states2 = self.task_compute_function(\n            
fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": layer_outputs1},\n            inputs2={\"input\": layer_outputs2},\n            compute_sync=False\n        )\n\n        if N-1 in self.offloading_blocks:\n            self.model.layers[N-1] = self.task_offload(\n                module=self.model.layers[N-1],\n                device=self.offloading_device)\n            \n        if self.model.final_layer_norm is not None:\n            # hidden_states = self.model.final_layer_norm(hidden_states)\n            hidden_states1, hidden_states2 = self.task_compute_module(\n                module=self.model.final_layer_norm,\n                inputs1={\"input\": hidden_states1},\n                inputs2={\"input\": hidden_states2},\n                grad=self.projected_grad,\n                weight_decay=0.)\n\n        if self.model.project_out is not None:\n            # hidden_states = self.model.project_out(hidden_states)\n            hidden_states1, hidden_states2 = self.task_compute_module(\n                module=self.model.project_out,\n                inputs1={\"input\": hidden_states1},\n                inputs2={\"input\": hidden_states2},\n                grad=self.projected_grad,\n                compute_sync=False)\n\n        return hidden_states1, hidden_states2\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.Tensor] = None,\n    ) -> 
Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either decoder_input_ids or decoder_inputs_embeds\")\n\n        if inputs_embeds is None:\n            # inputs_embeds = self.model.embed_tokens(input_ids)\n            inputs_embeds = self.task_compute_module(self.model.embed_tokens, \n                                                     inputs1={\"input\": input_ids},\n                                                     inputs2=None,\n                                                     grad=None)\n\n        batch_size, seq_length = input_shape\n        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        # required mask seq length can be calculated via length of past\n        mask_seq_length = past_key_values_length + seq_length\n        \n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        if cache_position is None:\n            cache_position = torch.arange(\n                
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)\n        # causal_mask = self.model._update_causal_mask(\n        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        # )\n        causal_attention_mask = self.task_compute_function(\n            self.model._update_causal_mask,\n            inputs1={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds, \"cache_position\": cache_position, \n                     \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            inputs2=None\n        )\n        # pos_embeds = self.model.embed_positions(attention_mask, past_key_values_length)\n        pos_embeds = self.task_compute_module(self.model.embed_positions,\n                                            inputs1={\"attention_mask\": attention_mask, \"past_key_values_length\": past_key_values_length},\n                                            inputs2=None,\n                                            grad=None,\n                                            compute_sync=False)\n\n        if self.model.project_in is not None:\n            # inputs_embeds = self.model.project_in(inputs_embeds)\n            inputs_embeds = self.task_compute_module(self.model.project_in,\n                                                    inputs1={\"input\": inputs_embeds},\n                                                    inputs2=None,\n                                                    grad=None,\n                                                    compute_sync=False)\n\n        # hidden_states = inputs_embeds + pos_embeds\n        hidden_states = self.task_compute_function(torch.add,\n                                                inputs1={\"input\": inputs_embeds, \"other\": 
pos_embeds},\n                                                inputs2=None,\n                                                compute_sync=False)\n        \n        if self.model.gradient_checkpointing and self.model.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n        \n        # decoder layers\n        # all_hidden_states = () if output_hidden_states else None\n        # all_self_attns = () if output_attentions else None\n        # next_decoder_cache = () if use_cache else None\n        all_hidden_states = self.task_compute_function(init_all_hidden_states,\n                                                       inputs1={\"output_hidden_states\": output_hidden_states},\n                                                       inputs2=None,\n                                                       compute_sync=False)\n        all_self_attns = self.task_compute_function(init_all_self_attns,\n                                                    inputs1={\"output_attentions\": output_attentions},\n                                                    inputs2=None,\n                                                    compute_sync=False)\n        next_decoder_cache = self.task_compute_function(init_next_decoder_cache,\n                                                        inputs1={\"use_cache\": use_cache},\n                                                        inputs2=None,\n                                                        compute_sync=False)\n\n        if 0 in self.offloading_blocks:\n            self.model.layers[0] = self.task_upload(\n                module=self.model.layers[0],\n                device=self.device\n            )\n\n        # check if head_mask has a correct number of layers specified if desired\n        for attn_mask, mask_name in zip([head_mask], 
[\"head_mask\"]):\n            if attn_mask is not None:\n                if attn_mask.size()[0] != (len(self.model.layers)):\n                    raise ValueError(\n                        f\"The `{mask_name}` should be specified for {len(self.model.layers)} layers, but it is for\"\n                        f\" {head_mask.size()[0]}.\"\n                    )\n\n        N = len(self.model.layers)\n        for i in range(1, N):\n\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.layers[i-2] = self.task_offload(\n                        module=self.model.layers[i-2],\n                        device=self.offloading_device)\n            \n            all_hidden_states = self.task_compute_function(\n                fn=update_all_hidden_states,\n                inputs1={\"output_hidden_states\": output_hidden_states, \"all_hidden_states\": all_hidden_states, \"hidden_states\": hidden_states},\n                inputs2=None,\n                compute_sync=False)\n\n            past_key_value = self.task_compute_function(\n                fn=get_past_key_value,\n                inputs1={\"past_key_values\": past_key_values, \"idx\": i},\n                inputs2=None,\n                compute_sync=False)\n\n            layer_outputs = self.task_compute_module(\n                self.model.layers[i-1],\n                inputs1={\"hidden_states\": hidden_states, \"attention_mask\": causal_attention_mask, \n                            \"layer_head_mask\": (head_mask[i-1] if head_mask is not None else None),\n                            \"past_key_value\": past_key_value,\n                            \"output_attentions\": output_attentions,\n                            \"use_cache\": use_cache},\n                inputs2=None,\n                grad=None)\n        \n            # hidden_states = layer_outputs[0]\n            hidden_states = self.task_compute_function(\n                
fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n                inputs1={\"input\": layer_outputs},\n                inputs2=None,\n                compute_sync=False)\n            \n            next_decoder_cache = self.task_compute_function(\n                fn=update_next_decoder_cache,\n                inputs1={\"use_cache\": use_cache, \"next_decoder_cache\": next_decoder_cache, \"layer_outputs\": layer_outputs, \"output_attentions\": output_attentions},\n                inputs2=None,\n                compute_sync=False)\n\n            all_self_attns = self.task_compute_function(\n                fn=update_all_self_attns,\n                inputs1={\"output_attentions\": output_attentions, \"all_self_attns\": all_self_attns, \"layer_outputs\": layer_outputs},\n                inputs2=None,\n                compute_sync=False)\n            \n            # an unknown bug here, need to synchronize the stream to avoid memory leak (only apears in opt-350m)\n            if i in range(1, N-1, 2) and i in self.offloading_blocks:\n                self.compute_stream.synchronize()   # a weird but useful trick to avoid memory leak\n\n            if i in self.offloading_blocks:\n                self.model.layers[i] = self.task_upload(\n                    module=self.model.layers[i],\n                    device=self.device)\n\n        if N-2 in self.offloading_blocks:\n            self.model.layers[N-2] = self.task_offload(\n                module=self.model.layers[N-2],\n                device=self.offloading_device)\n        \n        all_hidden_states = self.task_compute_function(\n            fn=update_all_hidden_states,\n            inputs1={\"output_hidden_states\": output_hidden_states, \"all_hidden_states\": all_hidden_states, \"hidden_states\": hidden_states},\n            inputs2=None)\n\n        layer_outputs = self.task_compute_module(\n            self.model.layers[N-1],\n            inputs1={\"hidden_states\": hidden_states, \"attention_mask\": 
causal_attention_mask, \n                        \"layer_head_mask\": (head_mask[i-1] if head_mask is not None else None),\n                        \"past_key_value\": past_key_value,\n                        \"output_attentions\": output_attentions,\n                        \"use_cache\": use_cache},\n            inputs2=None,\n            grad=None)\n\n        hidden_states = self.task_compute_function(\n            fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": layer_outputs},\n            inputs2=None,\n            compute_sync=False)\n\n        next_decoder_cache = self.task_compute_function(\n            fn=update_next_decoder_cache,\n            inputs1={\"use_cache\": use_cache, \"next_decoder_cache\": next_decoder_cache, \"layer_outputs\": layer_outputs, \"output_attentions\": output_attentions},\n            inputs2=None,\n            compute_sync=False)\n\n        all_self_attns = self.task_compute_function(\n            fn=update_all_self_attns,\n            inputs1={\"output_attentions\": output_attentions, \"all_self_attns\": all_self_attns, \"layer_outputs\": layer_outputs},\n            inputs2=None,\n            compute_sync=False\n        )\n            \n        if N-1 in self.offloading_blocks:\n            self.model.layers[N-1] = self.task_offload(\n                module=self.model.layers[N-1],\n                device=self.offloading_device)\n            \n        if self.model.final_layer_norm is not None:\n            # hidden_states = self.model.final_layer_norm(hidden_states)\n            hidden_states = self.task_compute_module(\n                module=self.model.final_layer_norm,\n                inputs1={\"input\": hidden_states},\n                inputs2=None,\n                grad=None)\n\n        if self.model.project_out is not None:\n            # hidden_states = self.model.project_out(hidden_states)\n            hidden_states = self.task_compute_module(\n                
module=self.model.project_out,\n                inputs1={\"input\": hidden_states},\n                inputs2=None,\n                grad=None,\n                compute_sync=False)\n\n        # add hidden states from the last decoder layer\n        # if output_hidden_states:\n        #     all_hidden_states += (hidden_states,)\n        all_hidden_states = self.task_compute_function(\n            fn=update_all_hidden_states,\n            inputs1={\"output_hidden_states\": output_hidden_states, \"all_hidden_states\": all_hidden_states, \"hidden_states\": hidden_states},\n            inputs2=None,\n            compute_sync=False\n        )\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass OptimizerOPTModel(MeZO2SGD):\n\n    def init_zo2(self):\n        self.upload_stream = None\n        self.offload_stream = None\n        self.compute_stream = None\n        self.zo_random_seed = None\n        self.rstate = None\n        self.rstate_queue = None\n        self.last_rstate = None\n        self.projected_grad = None\n        self.init_zo2_upload()\n    \n    def init_zo2_upload(self):\n        ...\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: 
Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\n        self.model.decoder.zo_training = True\n        self.assign_zo2_attributes(self, self.model.decoder.opt)\n        output = self.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.decoder.opt, self)\n        \n        return output\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None 
else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        self.model.decoder.zo_training = False\n        self.assign_zo2_attributes(self, self.model.decoder.opt)\n        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\n        decoder_outputs = self.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.decoder.opt, self)\n\n        if not return_dict:\n            return decoder_outputs\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=decoder_outputs.last_hidden_state,\n            past_key_values=decoder_outputs.past_key_values,\n            hidden_states=decoder_outputs.hidden_states,\n            attentions=decoder_outputs.attentions,\n        )\n\n\nclass OptimizerOPTForCausalLM(MeZO2SGD):\n\n    def init_zo2_upload(self):\n        self.model.lm_head = self.model.lm_head.to(self.device)\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: 
Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        self.model.model.decoder.zo_training = True\n        self.assign_zo2_attributes(self, self.model.model.decoder.opt)\n        hidden_states1, hidden_states2 = self.model.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.decoder.opt, self)\n\n        # logits = self.model.lm_head(outputs[0]).contiguous()\n        logits1, logits2 = self.task_compute_module(self.model.lm_head,\n                                                    inputs1={\"input\": hidden_states1},\n                                                    inputs2={\"input\": hidden_states2},\n                                                    grad=self.projected_grad)\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n    
        for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                (input_ids, logits1, labels), (input_ids, logits2, labels) = \\\n                    self.task_compute_function(pre_hook_fn,\n                        inputs1={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels},\n                        inputs2={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels})\n        \n        # loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn,\n                inputs1={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels, **kwargs},\n                inputs2={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels, **kwargs})\n        elif labels is not None:\n            # Shift so that tokens < n predict n\n            # shift_logits = logits[..., :-1, :].contiguous()\n            shift_logits1, shift_logits2 = self.task_compute_function(\n                fn=get_shift_logits,\n                inputs1={\"logits\": logits1},\n                inputs2={\"logits\": logits2})\n            # shift_labels = labels[..., 1:].contiguous()\n            shift_labels1, shift_labels2 = self.task_compute_function(\n                fn=get_shift_labels,\n                inputs1={\"labels\": labels},\n                inputs2={\"labels\": labels})\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            # loss = loss_fct(shift_logits.view(-1, self.model.config.vocab_size), shift_labels.view(-1))\n            loss1, loss2 = self.task_compute_function(\n                fn=loss_fct,\n                inputs1={\"input\": shift_logits1.view(-1, self.model.config.vocab_size), \"target\": shift_labels1.view(-1)},\n                inputs2={\"input\": shift_logits2.view(-1, self.model.config.vocab_size), \"target\": 
shift_labels2.view(-1)})\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                (loss1, input_ids, logits1, labels), (loss2, input_ids, logits2, labels) = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1={\"self\": self.model, \"loss\": loss1, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels},\n                        inputs2={\"self\": self.model, \"loss\": loss2, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels})\n\n        return loss1, loss2\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        self.model.model.decoder.zo_training = False\n        self.assign_zo2_attributes(self, self.model.model.decoder.opt)\n        outputs = self.model.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            
past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.decoder.opt, self)\n\n        hidden_states = self.task_compute_function(\n            fn_get_opt_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": outputs},\n            inputs2=None,\n            compute_sync=False\n        )\n        \n        logits = self.task_compute_module(self.model.lm_head,\n                                        inputs1={\"input\": hidden_states},\n                                        inputs2=None,\n                                        grad=self.projected_grad)\n\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = \\\n                    self.task_compute_function(pre_hook_fn,\n                        inputs1=([self.model], {\"input_ids\": input_ids, \"logits\": logits, \"labels\": labels}),\n                        inputs2=None)\n        \n        loss = None\n        if self.model.zo_custom_eval_loss_fn:\n            loss = self.task_compute_function(\n                fn=self.model.zo_custom_eval_loss_fn,\n                inputs1=([self.model], {\"input_ids\": input_ids, \"logits\": logits, \"labels\": labels, **kwargs}),\n                inputs2=None,\n                compute_sync=False\n            )\n        elif labels is not None:\n            # Shift so that tokens < n predict n\n            # shift_logits = logits[..., :-1, :].contiguous()\n            shift_logits = self.task_compute_function(\n                fn=get_shift_logits,\n                inputs1={\"logits\": logits},\n                inputs2=None)\n            # shift_labels = labels[..., 1:].contiguous()\n    
        shift_labels = self.task_compute_function(\n                fn=get_shift_labels,\n                inputs1={\"labels\": labels},\n                inputs2=None)\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            # loss = loss_fct(shift_logits.view(-1, self.model.config.vocab_size), shift_labels.view(-1))\n            loss = self.task_compute_function(\n                fn=loss_fct,\n                inputs1={\"input\": shift_logits.view(-1, self.model.config.vocab_size), \"target\": shift_labels.view(-1)},\n                inputs2=None)\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                output, input_ids, logits, labels = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1=([self.model], {\"loss\": loss, \"input_ids\": input_ids, \"logits\": logits, \"labels\": labels}),\n                        inputs2=None)\n        \n        if not return_dict:\n            output = (logits,) + outputs[1]\n            return (loss,) + output if loss is not None else output\n        \n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass OptimizerOPTForSequenceClassification(MeZO2SGD):\n\n    def init_zo2_upload(self):\n        self.model.score = self.model.score.to(self.device)\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: 
Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        self.model.model.decoder.zo_training = True\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        hidden_states1, hidden_states2 = self.model.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n        \n        # hidden_states = transformer_outputs[0]\n        # logits = self.model.score(hidden_states)\n        logits1, logits2 = self.task_compute_module(self.model.score,\n                                                    inputs1={\"input\": hidden_states1},\n                                                    inputs2={\"input\": hidden_states2},\n                                                    grad=self.projected_grad)\n\n        if input_ids is not None:\n            batch_size, sequence_length = input_ids.shape[:2]\n        else:\n            batch_size, sequence_length = inputs_embeds.shape[:2]\n\n        if self.model.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (torch.ne(input_ids, 
self.model.config.pad_token_id).sum(-1) - 1).to(logits1.device)\n            else:\n                sequence_lengths = -1\n                logger.warning(\n                    f\"{self.model.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be \"\n                    \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n                )\n\n        # pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n        pooled_logits1, pooled_logits2 = self.task_compute_function(\n            fn=get_pooled_logits,\n            inputs1={\"logits\": logits1, \"batch_size\": batch_size, \"sequence_lengths\": sequence_lengths},\n            inputs2={\"logits\": logits2, \"batch_size\": batch_size, \"sequence_lengths\": sequence_lengths},)\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                (input_ids, pooled_logits1, labels), (input_ids, pooled_logits2, labels) = \\\n                    self.task_compute_function(pre_hook_fn,\n                        inputs1={\"self\": self, \"input_ids\": input_ids, \"logits\": pooled_logits1, \"labels\": labels},\n                        inputs2={\"self\": self, \"input_ids\": input_ids, \"logits\": pooled_logits2, \"labels\": labels})\n        \n        # loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn,\n                inputs1={\"self\": self.model, \"input_ids\": input_ids, \"logits\": pooled_logits1, \"labels\": labels, **kwargs},\n                inputs2={\"self\": self.model, \"input_ids\": input_ids, \"logits\": pooled_logits2, \"labels\": labels, **kwargs})\n        elif labels is not None:\n            if self.model.config.problem_type is None:\n                if self.model.num_labels == 1:\n                    self.model.config.problem_type = 
\"regression\"\n                elif self.model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.model.config.problem_type = \"single_label_classification\"\n                else:\n                    self.model.config.problem_type = \"multi_label_classification\"\n\n            if self.model.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.model.num_labels == 1:\n                    # loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                    loss1, loss2 = self.task_compute_function(\n                        fn=loss_fct,\n                        inputs1={\"input\": pooled_logits1.squeeze(), \"target\": labels.squeeze()},\n                        inputs2={\"input\": pooled_logits2.squeeze(), \"target\": labels.squeeze()},)\n                else:\n                    # loss = loss_fct(pooled_logits, labels)\n                    loss1, loss2 = self.task_compute_function(\n                        fn=loss_fct,\n                        inputs1={\"input\": pooled_logits1, \"target\": labels},\n                        inputs2={\"input\": pooled_logits2, \"target\": labels},)\n            elif self.model.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                # loss = loss_fct(pooled_logits.view(-1, self.model.num_labels), labels.view(-1))\n                loss1, loss2 = self.task_compute_function(\n                    fn=loss_fct,\n                    inputs1={\"input\": pooled_logits1.view(-1, self.model.num_labels), \"target\": labels.view(-1)},\n                    inputs2={\"input\": pooled_logits2.view(-1, self.model.num_labels), \"target\": labels.view(-1)},)\n            elif self.model.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                # loss = loss_fct(pooled_logits, labels)\n                loss1, loss2 
= self.task_compute_function(\n                    fn=loss_fct,\n                    inputs1={\"input\": pooled_logits1, \"target\": labels},\n                    inputs2={\"input\": pooled_logits2, \"target\": labels},)\n        \n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                (loss1, input_ids, pooled_logits1, labels), (loss2, input_ids, pooled_logits2, labels) = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1={\"self\": self.model, \"loss\": loss1, \"input_ids\": input_ids, \"logits\": pooled_logits1, \"labels\": labels},\n                        inputs2={\"self\": self.model, \"loss\": loss2, \"input_ids\": input_ids, \"logits\": pooled_logits2, \"labels\": labels})\n\n        return loss1, loss2\n        \n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.model.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.model.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.model.model.config.use_return_dict\n\n        
self.model.model.zo_training = False\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        transformer_outputs = self.model.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n\n        hidden_states = self.task_compute_function(\n            fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": transformer_outputs},\n            inputs2=None)\n\n        logits = self.task_compute_module(self.model.score,\n                                        inputs1={\"input\": hidden_states},\n                                        inputs2=None,\n                                        grad=self.projected_grad)\n\n        pooled_logits = self.task_compute_function(\n            fn=get_opt_sequence_classification_pooled_logits,\n            inputs1=([self.model], {\"logits\": logits, \"input_ids\": input_ids, \"inputs_embeds\": inputs_embeds}),\n            inputs2=None,\n            compute_sync=False)\n\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = self.task_compute_function(pre_hook_fn,\n                    inputs1=([self.model], {\"input_ids\": input_ids, \"logits\": logits, \"labels\": labels}),\n                    inputs2=None,\n                    compute_sync=False)\n\n        loss = None\n        if self.model.zo_custom_eval_loss_fn:\n            loss = self.task_compute_function(\n                fn=self.model.zo_custom_eval_loss_fn,\n                inputs1=([self.model], {\"input_ids\": 
input_ids, \"pooled_logits\": pooled_logits, \"labels\": labels, **kwargs}),\n                inputs2=None,\n                compute_sync=False\n            )\n        elif labels is not None:\n            loss = self.task_compute_function(\n                fn=get_opt_sequence_classification_loss,\n                inputs1=([self.model], {\"loss\": loss, \"pooled_logits\": pooled_logits, \"labels\": labels}),\n                inputs2=None,\n                compute_sync=False\n            )\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                transformer_outputs, input_ids, logits, labels = self.task_compute_function(post_hook_fn,\n                    inputs1=([self.model], {\"transformer_outputs\": transformer_outputs, \"input_ids\": input_ids, \"pooled_logits\": pooled_logits, \"labels\": labels}),\n                    inputs2=None,\n                    compute_sync=False)\n\n        if not return_dict:\n            transformer_outputs = (logits,) + transformer_outputs[1:]\n            return ((loss,) + transformer_outputs) if loss is not None else transformer_outputs\n        \n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\nclass OptimizerOPTForQuestionAnswering(MeZO2SGD):\n    \n    def init_zo2_upload(self):\n        self.model.qa_outputs = self.model.qa_outputs.to(self.device)\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        
inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n        \n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        self.model.model.decoder.zo_training = True\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        hidden_states1, hidden_states2 = self.model.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n        \n        # hidden_states = transformer_outputs[0]\n\n        # logits = self.model.qa_outputs(hidden_states)\n        logits1, logits2 = self.task_compute_module(self.model.qa_outputs,\n                                                    inputs1={\"input\": hidden_states1},\n                                                    inputs2={\"input\": hidden_states2},\n                                                    grad=self.projected_grad)\n        # start_logits, end_logits = logits.split(1, dim=-1)\n        # start_logits = start_logits.squeeze(-1).contiguous()\n        # end_logits = end_logits.squeeze(-1).contiguous()\n        (start_logits1, end_logits1), (start_logits2, end_logits2) = 
self.task_compute_function(\n            fn=get_start_logits_and_end_logits,\n            inputs1={\"logits\": logits1},\n            inputs2={\"logits\": logits2},)\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                (input_ids, start_logits1, start_positions, end_logits1, end_positions), (input_ids, start_logits2, start_positions, end_logits2, end_positions) = \\\n                    self.task_compute_function(pre_hook_fn,\n                        inputs1={\"self\": self, \"input_ids\": input_ids, \"start_logits\": start_logits1, \"start_positions\": start_positions, \"end_logits\": end_logits1, \"end_positions\": end_positions},\n                        inputs2={\"self\": self, \"input_ids\": input_ids, \"start_logits\": start_logits2, \"start_positions\": start_positions, \"end_logits\": end_logits2, \"end_positions\": end_positions})\n        \n        # total_loss = None\n        if self.model.zo_custom_train_loss_fn:\n            loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn,\n                inputs1={\"self\": self.model, \"input_ids\": input_ids, \"start_logits\": start_logits1, \"start_positions\": start_positions, \"end_logits\": end_logits1, \"end_positions\": end_positions, **kwargs},\n                inputs2={\"self\": self.model, \"input_ids\": input_ids, \"start_logits\": start_logits2, \"start_positions\": start_positions, \"end_logits\": end_logits2, \"end_positions\": end_positions, **kwargs})\n        elif start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we 
ignore these terms\n            ignored_index = start_logits1.size(1)\n            start_positions = start_positions.clamp(0, ignored_index)\n            end_positions = end_positions.clamp(0, ignored_index)\n            \n            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            # start_loss = loss_fct(start_logits, start_positions)\n            # end_loss = loss_fct(end_logits, end_positions)\n            # total_loss = (start_loss + end_loss) / 2\n            loss1, loss2 = self.task_compute_function(\n                fn=get_qa_loss,\n                inputs1={\"loss_fct\": loss_fct, \"start_logits\": start_logits1, \"start_positions\": start_positions, \"end_logits\": end_logits1, \"end_positions\": end_positions},\n                inputs2={\"loss_fct\": loss_fct, \"start_logits\": start_logits2, \"start_positions\": start_positions, \"end_logits\": end_logits2, \"end_positions\": end_positions})\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                (loss1, input_ids, start_logits1, start_positions, end_logits1, end_positions), (loss2, input_ids, start_logits2, start_positions, end_logits2, end_positions) = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1={\"self\": self.model, \"loss\": loss1, \"input_ids\": input_ids, \"start_logits\": start_logits1, \"start_positions\": start_positions, \"end_logits\": end_logits1, \"end_positions\": end_positions},\n                        inputs2={\"self\": self.model, \"loss\": loss2, \"input_ids\": input_ids, \"start_logits\": start_logits2, \"start_positions\": start_positions, \"end_logits\": end_logits2, \"end_positions\": end_positions})\n\n        return loss1, loss2\n        \n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: 
Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict\n\n        self.model.model.zo_training = False\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        transformer_outputs = self.model.model(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n        \n        hidden_states = self.task_compute_function(\n            fn=fn_get_opt_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": transformer_outputs},\n            inputs2=None)\n\n        logits = self.task_compute_module(self.model.qa_outputs,\n                                        inputs1={\"input\": hidden_states},\n                                        inputs2=None,\n                                        grad=self.projected_grad)\n\n        start_logits, end_logits = self.task_compute_function(\n            fn=get_start_logits_and_end_logits,\n            inputs1={\"logits\": logits},\n            inputs2=None,\n            
compute_sync=False)\n\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, start_logits, start_positions, end_logits, end_positions = self.task_compute_function(pre_hook_fn,\n                    inputs1=([self.model], {\"input_ids\": input_ids, \"start_logits\": start_logits, \"start_positions\": start_positions, \"end_logits\": end_logits, \"end_positions\": end_positions}),\n                    inputs2=None,\n                    compute_sync=False)\n\n        total_loss = None\n        if self.model.zo_custom_eval_loss_fn:\n            total_loss = self.task_compute_function(self.model.zo_custom_eval_loss_fn,\n                inputs1=([self.model], {\"input_ids\": input_ids, \"start_logits\": start_logits, \"start_positions\": start_positions, \"end_logits\": end_logits, \"end_positions\": end_positions, **kwargs}),\n                inputs2=None,\n                compute_sync=False)\n        elif start_positions is not None and end_positions is not None:\n            total_loss = self.task_compute_function(\n                fn=get_opt_question_answering_loss,\n                inputs1={\"total_loss\": total_loss, \"start_logits\": start_logits, \"start_positions\": start_positions, \"end_logits\": end_logits, \"end_positions\": end_positions},\n                inputs2=None,\n                compute_sync=False)\n        \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                transformer_outputs, input_ids, start_logits, start_positions, end_logits, end_positions = self.task_compute_function(post_hook_fn,\n                    inputs1=([self.model], {\"transformer_outputs\": transformer_outputs, \"input_ids\": input_ids, \"start_logits\": start_logits, \"start_positions\": start_positions, \"end_logits\": end_logits, \"end_positions\": end_positions}),\n                    
inputs2=None,\n                    compute_sync=False)\n\n        if not return_dict:\n            output = (start_logits, end_logits) + transformer_outputs[2:]\n            return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "zo2/model/huggingface/qwen3/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n    mezo_sgd,\n)\n\ndef get_qwen3_for_causalLM(zo_config):\n    zo2_supported_configs = {\n        \"mezo-sgd\": mezo_sgd.get_qwen3_for_causalLM_mezo_sgd,\n    }\n    return zo2_supported_configs[zo_config.zo_method](zo_config)\n"
  },
  {
    "path": "zo2/model/huggingface/qwen3/mezo_sgd/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import zo, zo2\nfrom .....config.mezo_sgd import MeZOSGDConfig\n\ndef get_qwen3_for_causalLM_mezo_sgd(config: MeZOSGDConfig):\n    return zo2.Qwen3ForCausalLM if config.zo2 else zo.Qwen3ForCausalLM\n"
  },
  {
    "path": "zo2/model/huggingface/qwen3/mezo_sgd/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\n\ndef fn_get_qwen3_decoder_hidden_states_from_layer_outputs(input):\n    return input[0]\n\ndef fn_get_qwen3_sliced_logits_from_hidden_states(hidden_states, slice_indices):\n    return hidden_states[:, slice_indices, :]"
  },
  {
    "path": "zo2/model/huggingface/qwen3/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.models.qwen3 import modeling_qwen3\nfrom transformers.models.qwen3.modeling_qwen3 import (\n    Qwen3Config,\n    Qwen3PreTrainedModel,\n    Qwen3RMSNorm,\n    Qwen3RotaryEmbedding,\n    Qwen3DecoderLayer,\n    CausalLMOutputWithPast,\n    BaseModelOutputWithPast,\n    KwargsForCausalLM,\n    can_return_tuple,\n    deprecate_kwarg,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n    QWEN3_INPUTS_DOCSTRING,\n    _CONFIG_FOR_DOC,\n)\nfrom transformers.utils import logging\n\nimport random\nfrom typing import List, Optional, Tuple, Union, Unpack\n\nfrom ....base import BaseZOModel\nfrom .....optimizer.mezo_sgd.zo import MeZOSGD\nfrom .....config.mezo_sgd import MeZOSGDConfig\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3DecoderLayer`]\n\n    Args:\n        config: Qwen3Config\n    \"\"\"\n\n    def __init__(self, config: Qwen3Config):\n        config.use_cache = False\n        Qwen3PreTrainedModel.__init__(self, config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.rotary_emb = Qwen3RotaryEmbedding(config=config)\n        self.layers = nn.ModuleList(\n            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n\nclass Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedModel, BaseZOModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config: Qwen3Config):\n        Qwen3PreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = Qwen3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.opt = OptimizerQwen3ForCausalLM(model=self, config=zo_config)\n\n    @can_return_tuple\n    @deprecate_kwarg(\"num_logits_to_keep\", version=\"4.50\", new_name=\"logits_to_keep\")\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        
attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM\n\n        >>> model = Qwen3ForCausalLM.from_pretrained(\"Qwen/Qwen3-8B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-8B\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        if self.zo_training:\n            use_cache = False\n            return self.opt.zo_forward(\n                input_ids, attention_mask, position_ids, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, \n                cache_position, logits_to_keep, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(super().forward, \n                input_ids, attention_mask, position_ids, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, \n                cache_position, logits_to_keep, **kwargs)\n\n\nclass OptimizerQwen3ForCausalLM(MeZOSGD):\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        \"\"\"\n            copy the original forward code and replace all 'self' to 'self.model'.\n        \"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n\n        # decoder outputs consists of 
(dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: BaseModelOutputWithPast = self.model.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.model.lm_head(hidden_states[:, slice_indices, :])\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        loss = None\n        if labels is not None:\n            if self.model.zo_custom_train_loss_fn:\n                loss = self.model.zo_custom_train_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n            else:\n                loss = self.model.loss_function(logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs)\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                loss, input_ids, logits, labels = post_hook_fn(self.model, loss, input_ids, logits, labels)\n\n        # add --> only return loss\n        return loss.detach()\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        eval_fn,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        
position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = pre_hook_fn(self.model, input_ids, logits, labels)\n\n        if self.model.zo_custom_eval_loss_fn:\n            output = eval_fn(input_ids, attention_mask, position_ids, \n                past_key_values, inputs_embeds, None, use_cache, \n                output_attentions, output_hidden_states, \n                cache_position, logits_to_keep, **kwargs)\n            logits = output[\"logits\"]\n            loss = None\n            if labels is not None:\n                loss = self.model.zo_custom_eval_loss_fn(self.model, input_ids, logits, labels, **kwargs)\n            output = CausalLMOutputWithPast(\n                loss=loss,\n                logits=logits,\n                past_key_values=output.past_key_values,\n                hidden_states=output.hidden_states,\n                attentions=output.attentions,\n            )\n        else:\n            output = eval_fn(input_ids, attention_mask, position_ids, \n                past_key_values, inputs_embeds, labels, use_cache, \n                output_attentions, output_hidden_states, \n                cache_position, logits_to_keep, **kwargs)\n            \n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n             
   output, input_ids, logits, labels = post_hook_fn(self.model, output, input_ids, logits, labels)\n        return output"
  },
  {
    "path": "zo2/model/huggingface/qwen3/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport random\nimport torch\nimport torch.nn as nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.models.qwen3 import modeling_qwen3\nfrom transformers.models.qwen3.modeling_qwen3 import (\n    Qwen3Config,\n    Qwen3PreTrainedModel,\n    Qwen3RMSNorm,\n    Qwen3RotaryEmbedding,\n    Qwen3DecoderLayer,\n    CausalLMOutputWithPast,\n    BaseModelOutputWithPast,\n    KwargsForCausalLM,\n    FlashAttentionKwargs,\n    partial,\n    can_return_tuple,\n    deprecate_kwarg,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n    QWEN3_INPUTS_DOCSTRING,\n    _CONFIG_FOR_DOC,\n)\nfrom transformers.utils import logging\n\nfrom typing import List, Optional, Tuple, Union, Unpack\n\nfrom ....base import BaseZOModel\nfrom .....optimizer.mezo_sgd.zo2 import MeZO2SGD\nfrom .....config.mezo_sgd import MeZOSGDConfig\nfrom .utils import *\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel, BaseZOModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`]\n\n    Args:\n        config: Qwen3Config\n    \"\"\"\n    def __init__(self, config: Qwen3Config):\n        \"\"\"\n        !!! 
Module register must follow the execution order.\n        \"\"\"\n        config.use_cache = False\n        Qwen3PreTrainedModel.__init__(self, config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.rotary_emb = Qwen3RotaryEmbedding(config=config)\n        self.layers = nn.ModuleList(\n            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        # Initialize ZO2\n        self.opt = OptimizerQwen3Model(model=self, config=zo_config)\n    \n    @can_return_tuple\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> BaseModelOutputWithPast:\n        if self.zo_training:\n            return self.opt.inner_zo_forward(input_ids, attention_mask, position_ids, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, cache_position,\n                **flash_attn_kwargs)\n        else:\n            return self.opt.zo_eval_forward(input_ids, attention_mask, 
position_ids, \n                past_key_values, inputs_embeds, use_cache, \n                output_attentions, output_hidden_states, cache_position,\n                **flash_attn_kwargs)\n\n\nclass Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedModel, BaseZOModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        Qwen3PreTrainedModel.__init__(self, config)\n        BaseZOModel.__init__(self)\n        self.model = Qwen3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def zo_init(self, zo_config):\n        self.model.zo_init(zo_config)\n        # Initialize ZO2\n        self.opt = OptimizerQwen3ForCausalLM(model=self, config=zo_config)\n\n    @can_return_tuple\n    @deprecate_kwarg(\"num_logits_to_keep\", version=\"4.50\", new_name=\"logits_to_keep\")\n    @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n    
    r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM\n\n        >>> model = Qwen3ForCausalLM.from_pretrained(\"Qwen/Qwen3-8B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-8B\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        if self.zo_training:\n            return self.opt.zo_forward(\n                input_ids, attention_mask, position_ids,\n                past_key_values, inputs_embeds, labels, use_cache,\n                output_attentions, output_hidden_states, cache_position,\n                logits_to_keep, **kwargs)\n        else:\n            return self.opt.zo_eval_forward(\n                input_ids, attention_mask, position_ids,\n                past_key_values, inputs_embeds, labels, use_cache,\n                output_attentions, output_hidden_states, cache_position,\n                logits_to_keep, **kwargs)\n\n\nclass OptimizerQwen3Model(MeZO2SGD):\n\n    def init_zo2(self):\n        self.upload_stream = None\n        self.offload_stream = None\n        self.compute_stream = None\n        self.zo_random_seed = None\n        self.rstate = None\n        self.rstate_queue = None\n        self.last_rstate = None\n        self.projected_grad = None\n        self.init_zo2_upload()\n    \n    def init_zo2_upload(self):\n        self.model.embed_tokens = self.model.embed_tokens.to(self.device)\n        self.model.rotary_emb = self.model.rotary_emb.to(self.device)\n        self.model.norm = self.model.norm.to(self.device)\n        self.num_blocks = len(self.model.layers)\n        if self.offloading_blocks is not None:\n            self.offloading_blocks = self.offloading_blocks\n        else:\n            self.offloading_blocks = list(range(self.num_blocks))\n        print(f\"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}\")\n        for i in range(self.num_blocks):\n            if i in self.offloading_blocks:\n                continue\n            else:\n                self.model.layers[i] = self.model.layers[i].to(self.device)\n                print(f\"Upload block {i} to {self.device}.\")\n        \n    @torch.inference_mode\n    def 
inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> BaseModelOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        # use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n        use_cache = False\n        \n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.model.gradient_checkpointing and self.model.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache\n        if not isinstance(past_key_values, (type(None), Cache)):\n            raise ValueError(\"The `past_key_values` should be either a `Cache` object or `None`.\")\n\n        if inputs_embeds is None:\n            # inputs_embeds = self.model.embed_tokens(input_ids)\n            inputs_embeds1, inputs_embeds2 = self.task_compute_module(\n                self.model.embed_tokens,\n                inputs1={\"input\": input_ids},\n                inputs2={\"input\": input_ids},\n                grad=self.projected_grad\n            )\n        else:\n            inputs_embeds1 = inputs_embeds2 = inputs_embeds\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds1.shape[1], device=inputs_embeds1.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask1, causal_mask2 = self.task_compute_function(\n            self.model._update_causal_mask,\n            inputs1={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds1, \"cache_position\": cache_position, \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            inputs2={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds2, \"cache_position\": cache_position, \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            compute_sync=False,\n        )\n\n        hidden_states1, hidden_states2 = inputs_embeds1, inputs_embeds2\n\n        # create position embeddings to be 
shared across the decoder layers\n        position_embeddings1, position_embeddings2 = self.task_compute_module(\n            self.model.rotary_emb,\n            inputs1={\"x\": hidden_states1, \"position_ids\": position_ids},\n            inputs2={\"x\": hidden_states2, \"position_ids\": position_ids},\n            grad=self.projected_grad,\n            compute_sync=False\n        )\n\n        if 0 in self.offloading_blocks:\n            self.model.layers[0] = self.task_upload(\n                module=self.model.layers[0],\n                device=self.device\n            )\n\n        N = self.model.config.num_hidden_layers\n        for i in range(1, N):\n\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.layers[i-2] = self.task_offload(\n                        module=self.model.layers[i-2],\n                        device=self.offloading_device)\n            \n            layer_outputs1, layer_outputs2 = self.task_compute_module(\n                self.model.layers[i-1],\n                inputs1={\"hidden_states\": hidden_states1, \"attention_mask\": causal_mask1, \"position_ids\": position_ids, \n                         \"past_key_value\": past_key_values, \"output_attentions\": output_attentions, \"use_cache\": use_cache, \n                         \"cache_position\": cache_position, \"position_embeddings\": position_embeddings1, **flash_attn_kwargs},\n                inputs2={\"hidden_states\": hidden_states2, \"attention_mask\": causal_mask2, \"position_ids\": position_ids, \n                         \"past_key_value\": past_key_values, \"output_attentions\": output_attentions, \"use_cache\": use_cache, \n                         \"cache_position\": cache_position, \"position_embeddings\": position_embeddings2, **flash_attn_kwargs},\n                grad=self.projected_grad\n            )\n\n            hidden_states1, hidden_states2 = self.task_compute_function(\n                
fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs,\n                inputs1={\"input\": layer_outputs1},\n                inputs2={\"input\": layer_outputs2},\n                compute_sync=False\n            )\n\n            if i in self.offloading_blocks:\n                self.model.layers[i] = self.task_upload(\n                    module=self.model.layers[i],\n                    device=self.device)\n\n        if N-2 in self.offloading_blocks:\n            self.model.layers[N-2] = self.task_offload(\n                module=self.model.layers[N-2],\n                device=self.offloading_device)\n        \n        layer_outputs1, layer_outputs2 = self.task_compute_module(\n            self.model.layers[N-1],\n            inputs1={\"hidden_states\": hidden_states1, \"attention_mask\": causal_mask1, \"position_ids\": position_ids, \n                        \"past_key_value\": past_key_values, \"output_attentions\": output_attentions, \"use_cache\": use_cache, \n                        \"cache_position\": cache_position, \"position_embeddings\": position_embeddings1, **flash_attn_kwargs},\n            inputs2={\"hidden_states\": hidden_states2, \"attention_mask\": causal_mask2, \"position_ids\": position_ids, \n                        \"past_key_value\": past_key_values, \"output_attentions\": output_attentions, \"use_cache\": use_cache, \n                        \"cache_position\": cache_position, \"position_embeddings\": position_embeddings2, **flash_attn_kwargs},\n            grad=self.projected_grad\n        )\n\n        hidden_states1, hidden_states2 = self.task_compute_function(\n            fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": layer_outputs1},\n            inputs2={\"input\": layer_outputs2},\n            compute_sync=False\n        )\n\n        if N-1 in self.offloading_blocks:\n            self.model.layers[N-1] = self.task_offload(\n                module=self.model.layers[N-1],\n                
device=self.offloading_device)\n            \n        hidden_states1, hidden_states2 = self.task_compute_module(\n            module=self.model.norm,\n            inputs1={\"hidden_states\": hidden_states1},\n            inputs2={\"hidden_states\": hidden_states2},\n            grad=self.projected_grad,\n            # weight_decay=0.\n        )\n\n        return hidden_states1, hidden_states2\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> BaseModelOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n        # use_cache = use_cache if use_cache is not None else self.model.config.use_cache\n        use_cache = False\n        \n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.model.gradient_checkpointing and self.model.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache\n        if not isinstance(past_key_values, (type(None), Cache)):\n            raise ValueError(\"The `past_key_values` should be either a `Cache` object or `None`.\")\n\n        if inputs_embeds is None:\n            # inputs_embeds = self.model.embed_tokens(input_ids)\n            inputs_embeds = self.task_compute_module(\n                self.model.embed_tokens,\n                inputs1={\"input\": input_ids},\n                inputs2=None,\n                grad=None\n            )\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self.task_compute_function(\n            self.model._update_causal_mask,\n            inputs1={\"attention_mask\": attention_mask, \"input_tensor\": inputs_embeds, \"cache_position\": cache_position, \"past_key_values\": past_key_values, \"output_attentions\": output_attentions},\n            inputs2=None,\n            compute_sync=False,\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.task_compute_module(\n            self.model.rotary_emb,\n            inputs1={\"x\": hidden_states, \"position_ids\": position_ids},\n            inputs2=None,\n            grad=None,\n            compute_sync=False\n        )\n\n        if 0 in self.offloading_blocks:\n            
self.model.layers[0] = self.task_upload(\n                module=self.model.layers[0],\n                device=self.device\n            )\n\n        N = self.model.config.num_hidden_layers\n        for i in range(1, N):\n\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.layers[i-2] = self.task_offload(\n                        module=self.model.layers[i-2],\n                        device=self.offloading_device)\n            \n            layer_outputs = self.task_compute_module(\n                self.model.layers[i-1],\n                inputs1={\"hidden_states\": hidden_states, \"attention_mask\": causal_mask, \"position_ids\": position_ids, \n                         \"past_key_value\": past_key_values, \"output_attentions\": output_attentions, \"use_cache\": use_cache, \n                         \"cache_position\": cache_position, \"position_embeddings\": position_embeddings, **flash_attn_kwargs},\n                inputs2=None,\n                grad=None\n            )\n\n            hidden_states = self.task_compute_function(\n                fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs,\n                inputs1={\"input\": layer_outputs},\n                inputs2=None,\n                compute_sync=False\n            )\n\n            if i in self.offloading_blocks:\n                self.model.layers[i] = self.task_upload(\n                    module=self.model.layers[i],\n                    device=self.device)\n\n        if N-2 in self.offloading_blocks:\n            self.model.layers[N-2] = self.task_offload(\n                module=self.model.layers[N-2],\n                device=self.offloading_device)\n        \n        layer_outputs = self.task_compute_module(\n            self.model.layers[N-1],\n            inputs1={\"hidden_states\": hidden_states, \"attention_mask\": causal_mask, \"position_ids\": position_ids,\n                      \"past_key_value\": past_key_values, 
\"output_attentions\": output_attentions, \"use_cache\": use_cache,\n                      \"cache_position\": cache_position, \"position_embeddings\": position_embeddings, **flash_attn_kwargs},\n            inputs2=None,\n            grad=None\n        )\n\n        hidden_states = self.task_compute_function(\n            fn=fn_get_qwen3_decoder_hidden_states_from_layer_outputs,\n            inputs1={\"input\": layer_outputs},\n            inputs2=None,\n            compute_sync=False\n        )\n\n        if N-1 in self.offloading_blocks:\n            self.model.layers[N-1] = self.task_offload(\n                module=self.model.layers[N-1],\n                device=self.offloading_device)\n\n        hidden_states = self.task_compute_module(\n            module=self.model.norm,\n            inputs1={\"hidden_states\": hidden_states},\n            inputs2=None,\n            grad=None,\n            # weight_decay=0.\n        )\n        \n        return hidden_states\n\nclass OptimizerQwen3ForCausalLM(MeZO2SGD):\n\n    def init_zo2_upload(self):\n        self.model.lm_head = self.model.lm_head.to(self.device)\n    \n    @torch.inference_mode\n    def inner_zo_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n     
   output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        self.model.model.zo_training = True\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        hidden_states1, hidden_states2 = self.model.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n        \n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        hidden_states1, hidden_states2 = self.task_compute_function(\n            fn_get_qwen3_sliced_logits_from_hidden_states,\n            inputs1={\"hidden_states\": hidden_states1, \"slice_indices\": slice_indices},\n            inputs2={\"hidden_states\": hidden_states2, \"slice_indices\": slice_indices},\n        )\n        logits1, logits2 = self.task_compute_module(self.model.lm_head,\n                                                    inputs1={\"input\": hidden_states1},\n                                                    inputs2={\"input\": hidden_states2},\n                                                    grad=self.projected_grad)\n\n        if self.model.zo_train_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_train_loss_fn_pre_hooks:\n                (input_ids, logits1, labels), (input_ids, logits2, labels) = \\\n    
                self.task_compute_function(pre_hook_fn,\n                        inputs1={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels},\n                        inputs2={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels})\n        \n        if labels is not None:\n            # loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)\n            if self.model.zo_custom_train_loss_fn:\n                loss1, loss2 = self.task_compute_function(self.model.zo_custom_train_loss_fn,\n                    inputs1={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels, **kwargs},\n                    inputs2={\"self\": self.model, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels, **kwargs})\n            else:\n                loss1, loss2 = self.task_compute_function(\n                    self.model.loss_function,\n                    inputs1={\"logits\": logits1, \"labels\": labels, \"vocab_size\": self.model.config.vocab_size, **kwargs},\n                    inputs2={\"logits\": logits2, \"labels\": labels, \"vocab_size\": self.model.config.vocab_size, **kwargs},\n                )\n\n        if self.model.zo_train_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_train_loss_fn_post_hooks:\n                (loss1, input_ids, logits1, labels), (loss2, input_ids, logits2, labels) = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1={\"self\": self.model, \"loss\": loss1, \"input_ids\": input_ids, \"logits\": logits1, \"labels\": labels},\n                        inputs2={\"self\": self.model, \"loss\": loss2, \"input_ids\": input_ids, \"logits\": logits2, \"labels\": labels})\n\n        return loss1, loss2\n\n    @torch.inference_mode\n    def inner_zo_eval_forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = 
None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        self.model.model.zo_training = False\n        self.assign_zo2_attributes(self, self.model.model.opt)\n        hidden_states = self.model.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        self.assign_zo2_attributes(self.model.model.opt, self)\n\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n\n        hidden_states = self.task_compute_function(\n            fn_get_qwen3_sliced_logits_from_hidden_states,\n            
inputs1={\"hidden_states\": hidden_states, \"slice_indices\": slice_indices},\n            inputs2=None,\n        )\n        logits = self.task_compute_module(self.model.lm_head,\n                                        inputs1={\"input\": hidden_states},\n                                        inputs2=None,\n                                        grad=None)\n\n        if self.model.zo_eval_loss_fn_pre_hooks != []:\n            for pre_hook_fn in self.model.zo_eval_loss_fn_pre_hooks:\n                input_ids, logits, labels = \\\n                    self.task_compute_function(pre_hook_fn,\n                        inputs1=([self.model], {\"input_ids\": input_ids, \"logits\": logits, \"labels\": labels}),\n                        inputs2=None)\n        \n        if labels is not None:\n            # loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)\n            if self.model.zo_custom_eval_loss_fn:\n                loss = self.task_compute_function(self.model.zo_custom_eval_loss_fn,\n                    inputs1=([self.model], {\"input_ids\": input_ids, \"logits\": logits, \"labels\": labels, **kwargs}),\n                    inputs2=None)\n            else:\n                loss = self.task_compute_function(\n                    self.model.loss_function,\n                    inputs1={\"logits\": logits, \"labels\": labels, \"vocab_size\": self.model.config.vocab_size, **kwargs},\n                    inputs2=None\n                )\n\n        if self.model.zo_eval_loss_fn_post_hooks != []:\n            for post_hook_fn in self.model.zo_eval_loss_fn_post_hooks:\n                loss, input_ids, logits, labels = \\\n                    self.task_compute_function(post_hook_fn,\n                        inputs1=([self.model], {\"loss\": loss, \"input_ids\": input_ids, \"logits\": logits, \"labels\": labels}),\n                        inputs2=None)\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n       
     logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )"
  },
  {
    "path": "zo2/model/huggingface/zo_init.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom contextlib import contextmanager\nimport torch\nimport transformers\n\nfrom . import (\n    opt,\n    # llama,\n    qwen3\n)\n\n_zo2_supported_models = {\n    transformers.OPTForCausalLM: opt.get_opt_for_causalLM,\n    transformers.OPTForSequenceClassification: opt.get_opt_for_sequence_classification,\n    transformers.OPTForQuestionAnswering: opt.get_opt_for_question_answering,\n\n    # transformers.LlamaForCausalLM: llama.get_llama_for_causalLM,\n\n    transformers.Qwen3ForCausalLM: qwen3.get_qwen3_for_causalLM,\n}\n\n@contextmanager\ndef zo_hf_init(zo_config):\n    try:\n        for orig_class, get_zo2_class in _zo2_supported_models.items():\n            if hasattr(transformers, orig_class.__name__):\n                zo2_class = get_zo2_class(zo_config)\n                setattr(transformers, orig_class.__name__, zo2_class)\n            else:\n                raise NotImplementedError(f\"Model '{orig_class.__name__}' is not supported in transformers.\")\n        yield\n    finally:\n        pass\n\ndef main():\n    # user api:\n    with zo_hf_init(zo_config):\n        from transformers import OPTForCausalLM\n        model = OPTForCausalLM.from_pretrained(...)\n        model.zo_init(zo_config)\n    print(type(model))  # should be zo2.OPTForCausalLM"
  },
  {
    "path": "zo2/model/nanogpt/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom . import (\n    mezo_sgd,\n)\n\ndef get_nanogpt(zo_config):\n    zo2_supported_configs = {\n        \"mezo-sgd\": mezo_sgd.get_nanogpt_mezo_sgd,\n    }\n    return zo2_supported_configs[zo_config.zo_method](zo_config)\n"
  },
  {
    "path": "zo2/model/nanogpt/mezo_sgd/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom ..model import GPTConfig, GPTConfigs, GPT\nfrom .zo import GPT as GPT_MeZOSGD\nfrom .zo2 import GPT as GPT_MeZO2SGD\nfrom ....config.mezo_sgd import MeZOSGDConfig\n\ndef get_nanogpt_mezo_sgd(config: MeZOSGDConfig):\n    return GPT_MeZO2SGD if config.zo2 else GPT_MeZOSGD"
  },
  {
    "path": "zo2/model/nanogpt/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .. import model\nfrom ...base import BaseZOModel\nfrom ....optimizer.mezo_sgd.zo import MeZOSGD\nfrom ....config.mezo_sgd import MeZOSGDConfig\n\n\nclass GPT(model.GPT, BaseZOModel):\n    def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig):\n        super().__init__(config)\n        self.opt = Optimizer(model=self, config=zo_config)\n\n    def forward(self, idx, pos, targets=None):\n        if self.zo_training:\n            return self.opt.zo_forward(idx, pos, targets)\n        else:\n            # for evaluate and inference purpose\n            return self.opt.zo_eval_forward(super().forward, idx, pos, targets)\n\n\nclass Optimizer(MeZOSGD):\n\n    @torch.inference_mode\n    def inner_zo_forward(self, idx, pos, targets):\n        tok_emb = self.model.transformer.wte(idx)\n        pos_emb = self.model.transformer.wpe(pos)\n        x = tok_emb + pos_emb\n        for block in self.model.transformer.h:\n            x = block(x)\n        x = self.model.transformer.ln_f(x)\n        x = self.model.lm_head(x)\n        loss = F.cross_entropy(\n            x.reshape(-1, x.size(-1)), \n            targets.reshape(-1)\n        )\n        return loss.detach()\n\n    @torch.inference_mode()   \n    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n        output = eval_fn(idx, pos, targets)\n        return output\n    "
  },
  {
    "path": "zo2/model/nanogpt/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom .. import model\nfrom ...base import BaseZOModel\nfrom ....optimizer.mezo_sgd.zo2 import MeZO2SGD\nfrom ....config.mezo_sgd import MeZOSGDConfig\n\n\nclass GPT(model.GPT, BaseZOModel):\n    def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig):\n        super().__init__(config)\n        self.opt = Optimizer(model=self, config=zo_config)\n\n    def forward(self, idx, pos, targets=None):\n        if self.zo_training:\n            return self.opt.zo_forward(idx, pos, targets)\n        else:\n            # for evaluate and inference purpose\n            return self.opt.zo_eval_forward(super().forward, idx, pos, targets)\n\n\nclass Optimizer(MeZO2SGD):\n    \n    def init_zo2_upload(self):\n        print(\"Upload head and tail to cuda.\")\n        self.model.transformer.wte = self.model.transformer.wte.to(self.device)\n        self.model.transformer.wpe = self.model.transformer.wpe.to(self.device)\n        self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device)\n        self.model.lm_head = self.model.lm_head.to(self.device)\n        \n        self.num_blocks = len(self.model.transformer.h)\n        if self.offloading_blocks is not None:\n            self.offloading_blocks = self.offloading_blocks\n        else:\n            self.offloading_blocks = list(range(self.num_blocks))\n        print(f\"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}\")\n        for i in range(self.num_blocks):\n            if i in self.offloading_blocks:\n                continue\n            else:\n                self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device)\n                print(f\"Upload block {i} to cuda.\")\n\n    @torch.inference_mode()   \n    def inner_zo_forward(self, idx, pos, targets):\n        we1, we2 = 
self.task_compute_module(self.model.transformer.wte,\n                                inputs1={\"input\": idx},\n                                inputs2={\"input\": idx},\n                                grad=self.projected_grad)\n        pe1, pe2 = self.task_compute_module(self.model.transformer.wpe, \n                                 {\"input\": pos}, \n                                 {\"input\": pos}, \n                                 self.projected_grad,\n                                 compute_sync=False)\n        hidden_states1, hidden_states2 = self.task_compute_function(torch.add,\n                                                                    {\"input\": we1, \"other\": pe1},\n                                                                    {\"input\": we2, \"other\": pe2},\n                                                                    compute_sync=False)\n        if 0 in self.offloading_blocks:\n            self.model.transformer.h[0] = self.task_upload(\n                module=self.model.transformer.h[0], \n                device=self.device)\n        N = len(self.model.transformer.h)\n        for i in range(1, N):\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.transformer.h[i-2] = self.task_offload(\n                        module=self.model.transformer.h[i-2], \n                        device=self.offloading_device)\n            hidden_states1, hidden_states2 = self.task_compute_module(\n                self.model.transformer.h[i-1], \n                inputs1={\"x\": hidden_states1}, \n                inputs2={\"x\": hidden_states2}, \n                grad=self.projected_grad)\n            if i in self.offloading_blocks:\n                self.model.transformer.h[i] = self.task_upload(\n                    module=self.model.transformer.h[i], \n                    device=self.device)\n        if N-2 in self.offloading_blocks:\n            self.model.transformer.h[N-2] = 
self.task_offload(\n                self.model.transformer.h[N-2], device=self.offloading_device)\n        hidden_states1, hidden_states2 = self.task_compute_module(\n                    self.model.transformer.h[N-1], \n                    inputs1={\"x\": hidden_states1}, \n                    inputs2={\"x\": hidden_states2}, \n                    grad=self.projected_grad\n                )\n        if N-1 in self.offloading_blocks:\n            self.model.transformer.h[N-1] = self.task_offload(\n                self.model.transformer.h[N-1], device=self.offloading_device)\n        logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f,\n                                             inputs1={\"input\": hidden_states1}, \n                                             inputs2={\"input\": hidden_states2}, \n                                             grad=self.projected_grad,\n                                             weight_decay=0.)   \n                    # 'task_compute_module' will remove the first name 'ln_f', so we need to disable weight_decay manually.\n        logits1, logits2 = self.task_compute_module(self.model.lm_head,\n                                             inputs1={\"input\": logits1}, \n                                             inputs2={\"input\": logits2}, \n                                             grad=self.projected_grad)\n        loss1, loss2 = self.task_compute_function(F.cross_entropy,\n                                                  {\"input\": logits1.reshape(-1, logits1.size(-1)), \n                                                   \"target\": targets.reshape(-1)},\n                                                  {\"input\": logits2.reshape(-1, logits2.size(-1)), \n                                                   \"target\": targets.reshape(-1)})\n        return loss1, loss2\n    \n    @torch.inference_mode()   \n    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n        handles = 
self.add_zo2_eval_comm_hooks(self.model.transformer.h)\n        output = eval_fn(idx, pos, targets)\n        self.clear_zo2_eval_comm_hooks(handles)\n        return output\n    "
  },
  {
    "path": "zo2/model/nanogpt/model.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n\"\"\"\nModified from https://github.com/karpathy/build-nanogpt/blob/master/train_gpt2.py\n\"\"\"\nimport sys\nsys.path.append(\"./zo2\")\n\nimport math\nimport inspect\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n@dataclass\nclass GPTConfig:\n    block_size: int = 1024\n    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency\n    n_layer: int = 12\n    n_head: int = 12\n    n_embd: int = 768\n    dropout: float = 0.0\n    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n\nclass GPTConfigs:\n    gpt2: GPTConfig = GPTConfig(n_layer=12, n_head=12, n_embd=768)\n    gpt2_medium: GPTConfig = GPTConfig(n_layer=24, n_head=16, n_embd=1024)\n    gpt2_large: GPTConfig = GPTConfig(n_layer=36, n_head=20, n_embd=1280)\n    gpt2_xl: GPTConfig = GPTConfig(n_layer=48, n_head=25, n_embd=1600)\n    opt_125m: GPTConfig = GPTConfig(n_layer=12, n_head=12, n_embd=768, block_size=2048)\n    opt_350m: GPTConfig = GPTConfig(n_layer=24, n_head=16, n_embd=1024, block_size=2048)\n    opt_1_3b: GPTConfig = GPTConfig(n_layer=24, n_head=32, n_embd=2048, block_size=2048)\n    opt_2_7b: GPTConfig = GPTConfig(n_layer=32, n_head=32, n_embd=2560, block_size=2048)\n    opt_6_7b: GPTConfig = GPTConfig(n_layer=32, n_head=32, n_embd=4096, block_size=2048)\n    opt_13b: GPTConfig = GPTConfig(n_layer=40, n_head=40, n_embd=5120, block_size=2048)\n    opt_30b: GPTConfig = GPTConfig(n_layer=48, n_head=56, n_embd=7168, block_size=2048)\n    opt_66b: GPTConfig = GPTConfig(n_layer=64, n_head=72, n_embd=9216, block_size=2048)\n    opt_175b: GPTConfig = GPTConfig(n_layer=96, n_head=96, n_embd=12288, block_size=2048)\n\nclass LayerNorm(nn.Module):\n    \"\"\" LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False \"\"\"\n\n    def __init__(self, ndim, bias):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(ndim))\n        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\n\n    def forward(self, input):\n        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\n\nclass CausalSelfAttention(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        assert config.n_embd % config.n_head == 0\n        # key, query, value projections for all heads, but in a batch\n        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)\n        # output projection\n        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n        # regularization\n        self.attn_dropout = nn.Dropout(config.dropout)\n        self.resid_dropout = nn.Dropout(config.dropout)\n        self.n_head = config.n_head\n        self.n_embd = config.n_embd\n        self.dropout = config.dropout\n        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0\n        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')\n        if not self.flash:\n            print(\"WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0\")\n            # causal mask to ensure that attention is only applied to the left in the input sequence\n            self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n                                        .view(1, 1, config.block_size, config.block_size))\n\n    def forward(self, x):\n        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n\n        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)\n        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n\n        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)\n        if self.flash:\n            # efficient attention using Flash Attention CUDA kernels\n            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)\n        else:\n            # manual implementation of attention\n            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))\n            att = F.softmax(att, dim=-1)\n            att = self.attn_dropout(att)\n            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n\n        # output projection\n        y = self.resid_dropout(self.c_proj(y))\n        return y\n\nclass MLP(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, 
bias=config.bias)\n        self.gelu    = nn.GELU()\n        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)\n        self.dropout = nn.Dropout(config.dropout)\n\n    def forward(self, x):\n        x = self.c_fc(x)\n        x = self.gelu(x)\n        x = self.c_proj(x)\n        x = self.dropout(x)\n        return x\n\nclass Block(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)\n        self.attn = CausalSelfAttention(config)\n        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)\n        self.mlp = MLP(config)\n\n    def forward(self, x):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass GPT(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        assert config.vocab_size is not None\n        assert config.block_size is not None\n        self.config = config\n\n        self.transformer = nn.ModuleDict(dict(\n            wte = nn.Embedding(config.vocab_size, config.n_embd),\n            wpe = nn.Embedding(config.block_size, config.n_embd),\n            drop = nn.Dropout(config.dropout),\n            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n            ln_f = LayerNorm(config.n_embd, bias=config.bias),\n        ))\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n        # with weight tying when using torch.compile() some warnings get generated:\n        # \"UserWarning: functional_call was passed multiple values for tied weights.\n        # This behavior is deprecated and will be an error in future versions\"\n        # not 100% sure what this is, so far seems to be harmless. 
TODO investigate\n        # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying\n\n        # init all weights\n        self.apply(self._init_weights)\n        # apply special scaled init to the residual projections, per GPT-2 paper\n        for pn, p in self.named_parameters():\n            if pn.endswith('c_proj.weight'):\n                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\n\n        # report number of parameters\n        print(\"number of parameters: %.2fM\" % (self.get_num_params()/1e6,))\n\n    def get_num_params(self, non_embedding=True):\n        \"\"\"\n        Return the number of parameters in the model.\n        For non-embedding count (default), the position embeddings get subtracted.\n        The token embeddings would too, except due to the parameter sharing these\n        params are actually used as weights in the final layer, so we include them.\n        \"\"\"\n        n_params = sum(p.numel() for p in self.parameters())\n        if non_embedding:\n            n_params -= self.transformer.wpe.weight.numel()\n        return n_params\n\n    def _init_weights(self, module):\n        if isinstance(module, nn.Linear):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n\n    def forward(self, idx, pos, targets=None):\n        # idx is of shape (B, T)\n        B, T = idx.size()\n        assert T <= self.config.block_size, f\"Cannot forward sequence of length {T}, block size is only {self.config.block_size}\"\n        # forward the token and posisition embeddings\n        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)\n        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)\n        x 
= tok_emb + pos_emb\n        # forward the blocks of the transformer\n        for block in self.transformer.h:\n            x = block(x)\n        # forward the final layernorm and the classifier\n        x = self.transformer.ln_f(x)\n        logits = self.lm_head(x) # (B, T, vocab_size)\n        loss = None\n        if targets is not None:\n            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))\n        return logits, loss\n\n    def crop_block_size(self, block_size):\n        # model surgery to decrease the block size if necessary\n        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)\n        # but want to use a smaller block size for some smaller, simpler model\n        assert block_size <= self.config.block_size\n        self.config.block_size = block_size\n        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])\n        for block in self.transformer.h:\n            if hasattr(block.attn, 'bias'):\n                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]\n\n    @classmethod\n    def from_pretrained(cls, model_type, override_args=None):\n        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\n        override_args = override_args or {} # default to empty dict\n        # only dropout can be overridden see more notes below\n        assert all(k == 'dropout' for k in override_args)\n        from transformers import GPT2LMHeadModel\n        print(\"loading weights from pretrained gpt: %s\" % model_type)\n\n        # n_layer, n_head and n_embd are determined from model_type\n        config_args = {\n            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params\n            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 
1558M params\n        }[model_type]\n        print(\"forcing vocab_size=50257, block_size=1024, bias=True\")\n        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\n        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints\n        config_args['bias'] = True # always True for GPT model checkpoints\n        # we can override the dropout rate, if desired\n        if 'dropout' in override_args:\n            print(f\"overriding dropout rate to {override_args['dropout']}\")\n            config_args['dropout'] = override_args['dropout']\n        # create a from-scratch initialized minGPT model\n        config = GPTConfig(**config_args)\n        model = GPT(config)\n        sd = model.state_dict()\n        sd_keys = sd.keys()\n        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\n\n        # init a huggingface/transformers model\n        model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n        sd_hf = model_hf.state_dict()\n\n        # copy while ensuring all of the parameters are aligned and match in names and shapes\n        sd_keys_hf = sd_hf.keys()\n        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\n        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)\n        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n        # basically the openai checkpoints use a \"Conv1D\" module, but we only want to use a vanilla Linear\n        # this means that we have to transpose these weights when we import them\n        assert len(sd_keys_hf) == len(sd_keys), f\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\"\n        for k in sd_keys_hf:\n            if any(k.endswith(w) for w in transposed):\n                # special treatment for the Conv1D weights we need to transpose\n                
assert sd_hf[k].shape[::-1] == sd[k].shape\n                with torch.no_grad():\n                    sd[k].copy_(sd_hf[k].t())\n            else:\n                # vanilla copy over the other parameters\n                assert sd_hf[k].shape == sd[k].shape\n                with torch.no_grad():\n                    sd[k].copy_(sd_hf[k])\n\n        return model\n\n    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n        # start with all of the candidate parameters\n        param_dict = {pn: p for pn, p in self.named_parameters()}\n        # filter out those that do not require grad\n        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n        optim_groups = [\n            {'params': decay_params, 'weight_decay': weight_decay},\n            {'params': nodecay_params, 'weight_decay': 0.0}\n        ]\n        num_decay_params = sum(p.numel() for p in decay_params)\n        num_nodecay_params = sum(p.numel() for p in nodecay_params)\n        print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n        print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n        # Create AdamW optimizer and use the fused version if it is available\n        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n        use_fused = fused_available and device_type == 'cuda'\n        extra_args = dict(fused=True) if use_fused else dict()\n        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n        print(f\"using 
fused AdamW: {use_fused}\")\n\n        return optimizer\n\n    def estimate_mfu(self, fwdbwd_per_iter, dt):\n        \"\"\" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS \"\"\"\n        # first estimate the number of flops we do per iteration.\n        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311\n        N = self.get_num_params()\n        cfg = self.config\n        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size\n        flops_per_token = 6*N + 12*L*H*Q*T\n        flops_per_fwdbwd = flops_per_token * T\n        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n        flops_achieved = flops_per_iter * (1.0/dt) # per second\n        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS\n        mfu = flops_achieved / flops_promised\n        return mfu\n\n    @torch.no_grad()\n    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n        \"\"\"\n        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n        the sequence max_new_tokens times, feeding the predictions back into the model each time.\n        Most likely you'll want to make sure to be in model.eval() mode of operation for this.\n        \"\"\"\n        for _ in range(max_new_tokens):\n            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\n            logits, _ = self(idx_cond)\n            logits = logits[:, -1, :] / temperature\n            if top_k is not None:\n                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n                logits[logits < v[:, [-1]]] = -float('Inf')\n            probs = F.softmax(logits, dim=-1)\n            idx_next = torch.multinomial(probs, num_samples=1)\n            idx = torch.cat((idx, idx_next), dim=1)\n\n        return idx"
  },
  {
    "path": "zo2/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "zo2/optimizer/base.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport torch\nfrom torch.optim.optimizer import Optimizer\n\nclass BaseOptimizer(Optimizer):    \n    \"\"\"\n    Base class for Zeroth-Order Optimization handling basic setup, including learning rate management.\n    This class is not intended for direct use but provides core functionalities for derived classes.\n    \"\"\"\n    def __init__(self, params, defaults):\n        \"\"\"\n        Initializes the BaseOptimizer.\n\n        Args:\n            params (iterable): Parameters to optimize or dicts defining parameter groups.\n            defaults (dict): Default optimization options.\n        \"\"\"\n        super().__init__(params, defaults)\n        self.lr = defaults[\"lr\"]\n        if len(self.param_groups) > 1:\n            raise NotImplementedError(\"Currently ZO2 does not support multi-group optimizing.\")\n    \n    def _update_lr(self):\n        self.lr = self.param_groups[0][\"lr\"]\n    \n    def _set_lr(self):\n        self.param_groups[0][\"lr\"] = self.lr"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/utils/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .com import *\nfrom .comm import *"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/utils/com.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\n"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/utils/comm.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport os\nimport torch\nimport torch.nn as nn\n\n\ndef module_to_bucket_inplace(module: nn.Module):\n    bucket = torch.cat([p.view(-1) for p in module.parameters()])\n    return bucket\n\ndef bucket_to_module_inplace(bucket: torch.Tensor, module: nn.Module):\n    offset = 0\n    for name, param in module.named_parameters():\n        num_elements = param.numel()\n        new_param = bucket[offset: offset+num_elements].view_as(param)\n        set_nested_attr(module, name, nn.Parameter(new_param, requires_grad=param.requires_grad))\n        offset += num_elements\n    return module\n\n\ndef create_disk_offload_path(path, module_id):\n    if os.path.isfile(path):\n        raise ValueError(\"'path' must be a dir.\")\n    elif os.path.isdir(path):\n        file_path = os.path.join(path, module_id, 'tmp.pt')\n        if not os.path.exists(path):\n            os.makedirs(path)\n    else:\n        os.makedirs(path)\n        file_path = os.path.join(path, module_id, 'tmp.pt')\n    return file_path\n\ndef get_disk_offload_path(path, module_id):\n    return os.path.join(path, module_id, 'tmp.pt')\n\ndef clear_disk_offload_path(path, module_id):\n    disk_offload_path = os.path.join(path, module_id)\n    if os.path.isdir(disk_offload_path):\n        if not os.listdir(disk_offload_path):\n            os.rmdir(disk_offload_path)\n\n\n\ndef set_nested_attr(obj, attr, value):\n    attrs = attr.split('.')\n    for attr in attrs[:-1]:\n        obj = getattr(obj, attr)\n    setattr(obj, attrs[-1], value)\n"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/zo.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append('./zo2')\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ..base import BaseOptimizer\nimport numpy as np\n\nfrom ...config.mezo_sgd import MeZOSGDConfig\n\n\nclass MeZOSGD(BaseOptimizer):\n    \"\"\"\n    Implements the [MeZO-SGD](https://arxiv.org/abs/2305.17333) optimization method, \n    particularly suited for scenarios with limited compute resources.\n    \"\"\"\n    def __init__(self, model: nn.Module, config: MeZOSGDConfig):\n        \"\"\"\n        Initializes the MeZOSGD optimizer which applies zeroth-order optimization techniques to the model parameters.\n\n        Args:\n            model (nn.Module): The model whose parameters will be optimized.\n            config (MeZOSGDConfig): Configuration object containing optimizer settings.\n        \"\"\"\n        self.config = config\n        self.model = model\n        self.lr = config.lr\n        self.weight_decay = config.weight_decay\n        self.zo_eps = config.eps\n        self.max_zo_random_seed = config.max_zo_random_seed\n        self.debug_mode = config.debug_mode\n        defaults = dict(\n            lr=self.lr,\n            weight_decay=self.weight_decay,\n            maximize=False,\n            foreach=None,\n            differentiable=False,\n            fused=None,\n        )\n        super().__init__(model.parameters(), defaults)\n        \n    @torch.inference_mode\n    def zo_perturb_parameters(self, module: nn.Module, scaling_factor: float=1):\n        \"\"\"\n        Applies Gaussian noise to parameters of a module, facilitating zeroth-order optimization.\n\n        Args:\n            module (nn.Module): Module whose parameters will be perturbed.\n            scaling_factor (float): Scaling factor for the noise applied to the parameters.\n        \"\"\"\n        for _, param in module.named_parameters():\n            if 
param.requires_grad:\n                # Resample z\n                if self.debug_mode:\n                    z = torch.ones_like(param.data) # for debug\n                else:\n                    z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)\n                param.data.add_(scaling_factor * z * self.zo_eps)\n\n    @torch.inference_mode\n    def zo_update(self, module, weight_decay=None):\n        \"\"\"\n        Updates the parameters of a module based on zeroth-order perturbations and optional weight decay.\n\n        Args:\n            module (nn.Module): Module whose parameters will be updated.\n            weight_decay (float, optional): Weight decay coefficient. If None, it defaults to the configuration.\n        \"\"\"\n        for name, param in module.named_parameters():\n            if param.requires_grad:\n                # Resample z\n                if self.debug_mode:\n                    z = torch.ones_like(param.data) # for debug\n                else:\n                    z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)\n                if weight_decay != None:\n                    param.data.sub_(\n                        self.lr * (self.projected_grad * z + weight_decay * param.data))\n                else:\n                    if all(x not in name for x in [\"bias\", \"layer_norm\", \"layernorm\", \"ln\"]):\n                        param.data.sub_(\n                            self.lr * (self.projected_grad * z + self.weight_decay * param.data))\n                    else:\n                        param.data.sub_(self.lr * self.projected_grad * z)\n    \n    def zo_perturb_shifts(self, first_perturb_shift=1, stride=2):\n        \"\"\"\n        Generates shifts for perturbing parameters in a pattern conducive to zeroth-order optimization.\n\n        Returns:\n            list: A list of perturb shifts used during the forward and 
update passes.\n        \"\"\"\n        return [first_perturb_shift, -stride, stride-first_perturb_shift]\n\n    def compute_grad(self, loss1, loss2):\n        return ((loss1 - loss2) / (2 * self.zo_eps)).item()\n        \n    @torch.inference_mode\n    def zo_forward(self, *args, zo_random_seed: int=None, **kwargs):\n        \"\"\"\n        Forward pass that applies zeroth-order perturbations to compute the loss, used for gradient estimation.\n        Notice that the application of Gaussian perturbations for the parameters during both the perturbation and update phases should be the same.\n\n        Args:\n            zo_random_seed (int, optional): Random seed for reproducibility of perturbations.\n        \"\"\"\n        self._update_lr()\n        self.zo_random_seed = zo_random_seed if zo_random_seed else np.random.randint(self.max_zo_random_seed)\n        torch.manual_seed(self.zo_random_seed)\n        self.zo_perturb_parameters(self.model, scaling_factor=self.zo_perturb_shifts()[0])\n        loss1 = self.inner_zo_forward(*args, **kwargs)\n        torch.manual_seed(self.zo_random_seed)\n        self.zo_perturb_parameters(self.model, scaling_factor=self.zo_perturb_shifts()[1])\n        loss2 = self.inner_zo_forward(*args, **kwargs)\n        self.projected_grad = self.compute_grad(loss1, loss2)\n        torch.manual_seed(self.zo_random_seed)\n        self.zo_perturb_parameters(self.model, scaling_factor=self.zo_perturb_shifts()[2])\n        torch.manual_seed(self.zo_random_seed)\n        self.zo_update(self.model)\n        return loss1\n\n    #*********************** evaluate ***********************#\n\n    @torch.inference_mode()\n    def zo_eval_forward(self, *args, **kwargs):\n        \"\"\"\n        Forward pass in evaluation mode.\n        \"\"\"\n        output = self.inner_zo_eval_forward(*args, **kwargs)\n        return output\n    \n    #*********************** api ***********************#\n\n    @torch.inference_mode\n    def inner_zo_forward(self, 
idx, pos, targets):\n        \"\"\"\n        Example of ZO inner_zo_forward:\n            Match the same args as the original model forward,\n            and replace all 'self' to 'self.model'.\n        \"\"\"\n        tok_emb = self.model.transformer.wte(idx)\n        pos_emb = self.model.transformer.wpe(pos)\n        x = tok_emb + pos_emb\n        for block in self.model.transformer.h:\n            x = block(x)\n        x = self.model.transformer.ln_f(x)\n        x = self.model.lm_head(x)\n        loss = F.cross_entropy(\n            x[:, :-1, :].reshape(-1, x.size(-1)), \n            targets[:, 1:].reshape(-1)\n        )\n        return loss.detach()\n\n    @torch.inference_mode()   \n    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n        output = eval_fn(idx, pos, targets)\n        return output\n    \n"
  },
  {
    "path": "zo2/optimizer/mezo_sgd/zo2.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nimport sys\nsys.path.append('./zo2')\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom collections import deque\n\nfrom .zo import MeZOSGD\nfrom ...config.mezo_sgd import MeZOSGDConfig\nfrom .utils import *\n\n\nclass MeZO2SGD(MeZOSGD):\n    first_call_eval = True  # Class variable specifically for tracking eval function\n    \n    \"\"\"\n    Extends MeZOSGD to support advanced offloading techniques that enhance the capability\n    to train large models on systems with limited GPU memory. It manages the intricate\n    balance between CPU and GPU, leveraging zeroth-order optimization with dynamic memory\n    management through offloading.\n    \"\"\"\n    def __init__(self, model, config: MeZOSGDConfig):\n        \"\"\"\n        Initializes the MeZO2SGD optimizer, setting up the necessary configuration for\n        offloading and optimization techniques.\n\n        Args:\n            model (nn.Module): The model whose parameters will be optimized.\n            config (MeZOSGDConfig): Configuration object specifying optimizer settings including\n                                    offloading and overlapping options.\n        \"\"\"\n        assert config.zo2, \"MeZO2SGD can only work with offloading.\"\n        super().__init__(model, config)\n        self.device = config.working_device\n        self.offloading_device = config.offloading_device\n        self.overlap = config.overlap\n        self.offloading_blocks = config.offloading_blocks\n        self.compute_module_optimize_method = config.compute_module_optimize_method\n        self.compute_function_optimize_method = config.compute_function_optimize_method\n        self.communicate_optimize_method = config.communicate_optimize_method\n        self.amp = config.amp\n        self.amp_precision = config.amp_precision\n        self.precision_on_offloading_device = 
config.precision_on_offloading_device\n        self.precision_on_working_device = config.precision_on_working_device\n        self.amp_compress_method = config.amp_compress_method\n        self.init_zo2()\n    \n    def init_zo2(self):\n        \"\"\"\n        Sets up CUDA streams and initializes the offloading and uploading mechanisms\n        required for efficient computation management across devices.\n        \"\"\"\n        self.upload_stream = torch.cuda.Stream()\n        self.offload_stream = torch.cuda.Stream()\n        self.compute_stream = torch.cuda.Stream()\n        self.zo_random_seed = None\n        self.rstate = None\n        self.rstate_queue = deque(maxlen=2)\n        self.last_rstate = None\n        self.projected_grad = 0\n        self.init_zo2_upload()\n        if self.amp: self.init_zo2_amp()\n    \n    def init_zo2_amp(self):\n        \"\"\"\n        Initializes the model parameters to use different precision levels based on their current device.\n        This method works with Automatic Mixed Precision (AMP) by setting the precision for parameters \n        based on whether they are located on the working device or the offloading device.\n        \"\"\"\n        working_device = torch.device(self.device)\n        offloading_device = torch.device(self.offloading_device)\n        for p in self.model.parameters():\n            if p.device == working_device:\n                p.data = p.data.to(dtype=self.precision_on_working_device)\n            elif p.device == offloading_device:\n                p.data = p.data.to(dtype=self.precision_on_offloading_device)\n            else:\n                raise ValueError(f\"Unsupported device found for parameter: {p.device}\")\n\n    def assign_zo2_attributes(self, source, target):\n        \"\"\"\n        Utility function to transfer ZO2 specific attributes from one module to another,\n        aiding in maintaining consistency across nested model architectures.\n\n        Args:\n            source: The 
source module from which attributes are copied.\n            target: The target module to which attributes are assigned.\n        \"\"\"\n        attrs_to_assign = ['upload_stream', 'offload_stream', 'compute_stream', \n                           'zo_random_seed', 'rstate', 'rstate_queue', 'last_rstate', \n                           'projected_grad']\n        for attr in attrs_to_assign:\n            setattr(target, attr, getattr(source, attr))\n    \n    @torch.inference_mode\n    def zo_update(self, module, weight_decay=None):\n        \"\"\"\n        Applies the computed gradients to update parameters of the module, potentially\n        including a weight decay term. This method is enhanced by managing CUDA state\n        to ensure consistent random number generation across calls.\n\n        Args:\n            module (nn.Module): The module whose parameters are to be updated.\n            weight_decay (float, optional): Optional weight decay for regularization.\n        \"\"\"\n        torch.cuda.set_rng_state(self.last_rstate)\n        super().zo_update(module, weight_decay=weight_decay)\n        self.last_rstate = torch.cuda.get_rng_state()\n        return module\n    \n    @torch.inference_mode()\n    def module_dual_forward(self, module, inputs1, inputs2, projected_grad=0., weight_decay=None):\n        \"\"\"\n        Performs two parallel forward computations with perturbed parameters to estimate\n        gradients. This function is key for zeroth-order gradient estimation with support\n        for optional weight decay during parameter update. 
\n        \n        Notice that the application of Gaussian perturbations for the parameters \n        during both the perturbation and update phases should be the same.\n\n        Args:\n            module (nn.Module): The module on which forward passes are conducted.\n            inputs1 (dict): Inputs for the first forward pass.\n            inputs2 (dict): Inputs for the second forward pass.\n            projected_grad (float): Projected gradient value used for updating parameters.\n            weight_decay (float, optional): Optional weight decay for regularization.\n        \"\"\"\n        if projected_grad != 0:\n            module = self.zo_update(module, weight_decay)\n        torch.cuda.set_rng_state(self.rstate)\n        self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[0])\n        output1 = module(**inputs1)\n        torch.cuda.set_rng_state(self.rstate)\n        self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[1])\n        output2 = module(**inputs2)\n        torch.cuda.set_rng_state(self.rstate)\n        self.zo_perturb_parameters(module, scaling_factor=self.zo_perturb_shifts()[2])\n        self.rstate = torch.cuda.get_rng_state()\n        return output1, output2\n    \n    @torch.inference_mode()\n    def function_dual_forward(self, fn, inputs1, inputs2):\n        \"\"\"\n        Executes a provided function twice with dual inputs, supporting the zeroth-order optimization process\n        by enabling the estimation of gradients through function outputs.\n\n        Args:\n            fn (callable): The function to be executed.\n            inputs1 (dict): Arguments for the first execution of the function.\n            inputs2 (dict): Arguments for the second execution of the function.\n\n        Returns:\n            tuple: Outputs from the two executions of the function.\n        \"\"\"\n        output1 = fn(**inputs1)\n        output2 = fn(**inputs2)\n        return output1, output2\n    \n    
@torch.inference_mode()\n    def zo_forward(self, *args, seed: int=None, **kwargs):\n        \"\"\"\n        The overarching forward function that integrates perturbation, gradient estimation,\n        and parameter update within a single coherent process, controlled by the seed for reproducibility.\n\n        Args:\n            seed (int, optional): Seed for random number generation to ensure reproducibility.\n        \"\"\"\n        self._update_lr()\n        self.zo_random_seed = seed if seed else np.random.randint(self.max_zo_random_seed)\n        torch.manual_seed(self.zo_random_seed)\n        torch.cuda.manual_seed(self.zo_random_seed)\n        self.rstate = torch.cuda.get_rng_state()\n        self.rstate_queue.append(self.rstate.clone())\n        if len(self.rstate_queue) == 2:\n            self.last_rstate = self.rstate_queue.popleft()\n        torch.cuda.synchronize()    # global sync to make sure all tasks finish\n        loss1, loss2 = self.inner_zo_forward(*args, **kwargs)\n        torch.cuda.synchronize()    # global sync to make sure all tasks finish\n        self.projected_grad = self.compute_grad(loss1, loss2)\n        return loss1.detach()\n    \n    #*********************** tasks ***********************#\n\n    def task_upload(self, module, device='cuda', upload_sync=False, *args, **kwargs):\n        \"\"\"\n        Handles the uploading of modules to the GPU, utilizing CUDA streams to potentially overlap\n        computation and communication for efficiency.\n\n        Args:\n            module (nn.Module): Module to be uploaded.\n            device (str): Target device for the upload.\n            upload_sync (bool): Whether to synchronize the upload stream before proceeding.\n        \"\"\"\n        if self.overlap:\n            if upload_sync:\n                self.upload_stream.synchronize()\n        with torch.cuda.stream(self.upload_stream if self.overlap else torch.cuda.current_stream()):\n            module = self.upload_impl(\n           
     module, \n                device, \n                self.offloading_device,\n                self.communicate_optimize_method, \n                non_blocking=self.overlap, \n                *args, **kwargs\n            )\n        return module\n\n    def task_offload(self, module, device='cpu', offload_sync=False, *args, **kwargs):\n        \"\"\"\n        Manages the offloading of modules to an alternative storage (e.g., CPU or disk), using CUDA streams\n        to manage dependencies and potentially overlap tasks.\n\n        Args:\n            module (nn.Module): Module to be offloaded.\n            device (str): Target device for the offload.\n            offload_sync (bool): Whether to synchronize the offload stream before proceeding.\n        \"\"\"\n        if self.overlap:\n            if offload_sync:\n                self.offload_stream.synchronize()\n            self.compute_stream.synchronize()   # offload depends on compute task\n        with torch.cuda.stream(self.offload_stream if self.overlap else torch.cuda.current_stream()):\n            module = self.offload_impl(\n                module, \n                device, \n                self.offloading_device,\n                self.communicate_optimize_method, \n                non_blocking=self.overlap, \n                *args, **kwargs\n            )\n        return module\n    \n    def task_compute_module(self, module, inputs1, inputs2, grad, compute_sync=False, weight_decay=None, *args, **kwargs):\n        \"\"\"\n        Conducts computations on a module with optional dual inputs for gradient estimation,\n        applying synchronization and CUDA streams for efficiency.\n\n        Args:\n            module (nn.Module): The module on which computations are to be performed.\n            inputs1 (dict): Inputs for the first computation.\n            inputs2 (dict, could be None): Inputs for the second computation, if performing dual forward.\n            grad (float): Gradient value to be 
applied.\n            compute_sync (bool): Whether to synchronize the compute stream before proceeding.\n            weight_decay (float, optional): Optional weight decay during the update.\n        \"\"\"\n        if self.overlap:\n            if compute_sync:\n                self.compute_stream.synchronize()\n            self.upload_stream.synchronize()   # module compute depends on upload task\n        with torch.cuda.stream(self.compute_stream if self.overlap else torch.cuda.current_stream()):\n            if inputs2 is not None:\n                return self.compute_module_impl(\n                    self.module_dual_forward,\n                    module,\n                    self.compute_module_optimize_method,\n                    inputs1=inputs1, \n                    inputs2=inputs2,\n                    projected_grad=grad,\n                    weight_decay=weight_decay,\n                    *args, **kwargs\n                )\n            elif isinstance(inputs1, list):\n                return self.compute_module_impl(\n                    None,\n                    module,\n                    self.compute_module_optimize_method,\n                    *inputs1,\n                    *args,\n                    **kwargs\n                )\n            elif isinstance(inputs1, dict):\n                return self.compute_module_impl(\n                    None,\n                    module,\n                    self.compute_module_optimize_method,\n                    *args,\n                    **inputs1,\n                    **kwargs\n                )\n            elif isinstance(inputs1, tuple):\n                return self.compute_module_impl(\n                    None,\n                    module,\n                    self.compute_module_optimize_method,\n                    *inputs1[0],\n                    *args,\n                    **inputs1[1],\n                    **kwargs\n                )\n            else:\n                raise 
ValueError(\"Invalid inputs type.\")\n    \n    def task_compute_function(self, fn, inputs1, inputs2, compute_sync=False, *args, **kwargs):\n        \"\"\"\n        Executes a provided function with dual input sets to facilitate parallel operations\n        and gradient estimation. This method integrates CUDA streams for efficient task execution.\n\n        Args:\n            fn (callable): The function to execute, typically a PyTorch operation or custom function.\n            inputs1 (dict): Arguments for the first execution of the function.\n            inputs2 (dict, could be None): Arguments for the second execution of the function.\n            compute_sync (bool): Whether to synchronize the compute stream before execution to ensure data readiness.\n        \"\"\"\n        if self.overlap:\n            if compute_sync:\n                self.compute_stream.synchronize()\n        with torch.cuda.stream(self.compute_stream if self.overlap else torch.cuda.current_stream()):\n            if inputs2 is not None:\n                return self.compute_function_impl(\n                    self.function_dual_forward,\n                    fn,\n                    self.compute_function_optimize_method,\n                    inputs1=inputs1, \n                    inputs2=inputs2,\n                    *args, **kwargs\n                )\n            elif isinstance(inputs1, list):\n                return self.compute_function_impl(\n                    None,\n                    fn, \n                    self.compute_function_optimize_method,\n                    *inputs1,\n                    *args,\n                    **kwargs\n                )\n            elif isinstance(inputs1, dict):\n                return self.compute_function_impl(\n                    None,\n                    fn, \n                    self.compute_function_optimize_method,\n                    *args,\n                    **inputs1,\n                    **kwargs\n                )\n            
elif isinstance(inputs1, tuple):\n                return self.compute_function_impl(\n                    None,\n                    fn, \n                    self.compute_function_optimize_method,\n                    *inputs1[0],\n                    *args,\n                    **inputs1[1],\n                    **kwargs\n                )\n            else:\n                raise ValueError(\"Invalid inputs type.\")\n\n    #*********************** evaluate ***********************#\n\n    @torch.inference_mode()\n    def zo_eval_forward(self, *args, **kwargs):\n        \"\"\"\n        Conducts a model evaluation using the internal forward method without applying any perturbations.\n        This method ensures all tasks finish before and after the evaluation to maintain synchronization.\n\n        Args:\n            *args, **kwargs: Arguments and keyword arguments for the model's forward method.\n        \"\"\"\n        if MeZO2SGD.first_call_eval:\n            print(\"Warning: ZO2 may not efficiently optimize the evaluation stage, which could result in slower performance.\")\n            MeZO2SGD.first_call_eval = False  # Disable the warning after the first call\n        torch.cuda.synchronize()    # global sync to make sure all tasks finish\n        output = self.inner_zo_eval_forward(*args, **kwargs)\n        torch.cuda.synchronize()    # global sync to make sure all tasks finish\n        return output\n    \n    def add_zo2_eval_comm_hooks(self, blocks):\n        \"\"\"\n        Attaches communication hooks to model blocks to manage data uploading and offloading during evaluation.\n        This helps in managing memory more efficiently during the eval phase.\n\n        Args:\n            blocks (list): List of model blocks to attach hooks to.\n\n        Returns:\n            list: A list of hook handles for managing lifecycle.\n        \"\"\"\n        handles = []\n        for block in blocks:\n            if isinstance(block, nn.Module):\n                
pre_handle = block.register_forward_pre_hook(self.eval_upload_hook)\n                post_handle = block.register_forward_hook(self.eval_offload_hook)\n                handles.append(pre_handle)\n                handles.append(post_handle)\n        return handles\n    \n    def clear_zo2_eval_comm_hooks(self, handles):\n        \"\"\"\n        Removes communication hooks from model blocks after evaluation to clean up and prevent memory leaks.\n\n        Args:\n            handles (list): List of hook handles to be removed.\n        \"\"\"\n        for handle in handles:\n            handle.remove()\n    \n    def eval_upload_hook(self, module, input):\n        \"\"\"\n        A forward pre-hook to upload a module to the GPU before its evaluation.\n\n        Args:\n            module (nn.Module): Module to be uploaded.\n            input: Input data for the module.\n        \"\"\"\n        self.upload_impl(\n            module, \n            self.device, \n            self.offloading_device\n        )\n        return input\n\n    def eval_offload_hook(self, module, input, output):\n        \"\"\"\n        A forward hook to offload a module from the GPU after its evaluation to free up memory.\n\n        Args:\n            module (nn.Module): Module to be offloaded.\n            input: Input data for the module.\n            output: Output from the module evaluation.\n        \"\"\"\n        if self.overlap:\n            with torch.cuda.stream(self.offload_stream):\n                self.offload_impl(\n                    module, \n                    self.offloading_device, \n                    self.offloading_device\n                )\n        else:\n            self.offload_impl(\n                module, \n                self.offloading_device, \n                self.offloading_device\n            )\n        return output\n    \n    #*********************** backend ***********************#\n\n    def upload_impl(\n            self,\n            module: nn.Module, 
\n            device: str, \n            offloading_device: str,\n            optimize_method: str = \"\", \n            module_id: str = None,\n            *args, **kwargs\n        ):\n        \"\"\"\n        Implements the logic for uploading model components to a specified device.\n        Supports various optimization methods to tailor the upload process for different computing environments.\n        \"\"\"\n        def _upload_impl(module, device, offloading_device, *args, **kwargs):\n            if offloading_device == \"cpu\":\n                module = module.to(device, *args, **kwargs)\n            else:\n                if module_id == None:\n                    raise ValueError(\"For disk offloading mode, 'module_id' cannot be None.\")\n                offloading_disk_path = get_disk_offload_path(offloading_device, module_id)\n                match type(module):\n                    case torch.Tensor:\n                        module = torch.load(offloading_disk_path, map_location=device)\n                    case nn.Module:\n                        module.load_state_dict(torch.load(offloading_disk_path, map_location=device))\n                    case _:\n                        raise ValueError\n                clear_disk_offload_path(offloading_device, module_id)\n            return module\n        match optimize_method:\n            case \"\":\n                module = _upload_impl(module, device, offloading_device, *args, **kwargs)\n            case \"bucket\":  # works on large-scale models\n                bucket = module_to_bucket_inplace(module)\n                bucket = _upload_impl(bucket, device, offloading_device, *args, **kwargs)\n                module = bucket_to_module_inplace(bucket, module)\n            case _:\n                raise NotImplementedError\n        if self.amp:    # after uploading, decompress the module to higher precision\n            module = self.amp_decompress_impl(module)\n        return module\n\n    def 
offload_impl(\n            self,\n            module: nn.Module, \n            device: str, \n            offloading_device: str,\n            optimize_method: str = \"\", \n            module_id: str = None,\n            *args, **kwargs\n        ):\n        \"\"\"\n        Implements the logic for offloading model components from the GPU to another storage,\n        such as CPU or disk, to manage GPU memory more efficiently.\n        \"\"\"\n        def _offload_impl(module, device, offloading_device, *args, **kwargs):\n            if offloading_device == \"cpu\":\n                module = module.to(device, *args, **kwargs)\n            else:\n                if module_id == None:\n                    raise ValueError(\"For disk offloading mode, 'module_id' cannot be None.\")\n                offloading_disk_path = create_disk_offload_path(offloading_device, module_id)\n                match type(module):\n                    case torch.Tensor:\n                        torch.save(module, offloading_disk_path)\n                    case nn.Module:\n                        torch.save(module.state_dict(), offloading_disk_path)\n                    case _:\n                        raise ValueError\n            return module\n        if self.amp:    # before offloading, compress the module to lower precision\n            module = self.amp_compress_impl(module)\n        match optimize_method:\n            case \"\":\n                module = _offload_impl(module, device, offloading_device, *args, **kwargs)\n            case \"bucket\":  # works on large-scale models\n                bucket = module_to_bucket_inplace(module)\n                bucket = _offload_impl(bucket, device, offloading_device, *args, **kwargs)\n                module = bucket_to_module_inplace(bucket, module)\n            case _:\n                raise NotImplementedError\n        return module\n        \n    def compute_module_impl(\n            self,\n            forward_fn,\n            module: 
torch.nn.Module,\n            optimize_method: str,\n            *args, \n            optimize_kwargs = None,\n            **kwargs\n        ):\n        \"\"\"\n        Manages the computation tasks on a module, applying various optimization methods\n        to enhance execution speed and efficiency.\n        \"\"\"\n        match optimize_method:\n            case \"\":\n                pass\n            case \"torch.compile\":   # may introduce some precision mismatch\n                module = torch.compile(module, **optimize_kwargs)\n            case _:\n                raise NotImplementedError\n        with torch.autocast(device_type=self.device, dtype=self.amp_precision, enabled=self.amp):\n            if forward_fn is None:\n                return module(*args, **kwargs)\n            else:\n                return forward_fn(module=module, *args, **kwargs)\n\n    def compute_function_impl(\n            self,\n            function_fn,\n            fn,\n            optimize_method: str,\n            *args, \n            optimize_kwargs = None,\n            **kwargs\n        ):\n        \"\"\"\n        Manages the computation tasks on a function, applying various optimization methods\n        to enhance function execution speed and efficiency.\n        \"\"\"\n        match optimize_method:\n            case \"\":\n                pass\n            case \"torch.jit.script\":   # may introduce some precision mismatch\n                fn = torch.jit.script(fn, **optimize_kwargs)\n            case _:\n                raise NotImplementedError\n        with torch.autocast(device_type=self.device, dtype=self.amp_precision, enabled=self.amp):\n            if function_fn is None:\n                return fn(*args, **kwargs)\n            else:\n                return function_fn(fn, *args, **kwargs)\n\n    def amp_decompress_impl(self, module: nn.Module) -> nn.Module:\n        \"\"\"\n        Converts the data type of module parameters to a higher precision typically 
used for computations.\n        This is part of the AMP process where parameters might be temporarily compressed to a lower precision\n        and need to be decompressed back to higher precision for accuracy-critical operations.\n\n        Args:\n            module (nn.Module): The module whose parameters will be decompressed.\n\n        Returns:\n            nn.Module: The module with parameters converted to higher precision.\n        \"\"\"\n        for p in module.parameters():\n            match self.amp_compress_method:\n                case \"naive\":\n                    p.data = p.data.to(dtype=self.precision_on_working_device)\n                case _:\n                    raise NotImplementedError\n        return module\n\n    def amp_compress_impl(self, module: nn.Module) -> nn.Module:\n        \"\"\"\n        Compresses the data type of module parameters to a lower precision typically used to save memory and \n        improve computational efficiency during less accuracy-critical operations.\n        \n        Args:\n            module (nn.Module): The module whose parameters will be compressed.\n\n        Returns:\n            nn.Module: The module with parameters converted to lower precision.\n        \"\"\"\n        for p in module.parameters():\n            match self.amp_compress_method:\n                case \"naive\":\n                    p.data = p.data.to(dtype=self.precision_on_offloading_device)\n                case _:\n                    raise NotImplementedError\n        return module\n\n    #*********************** api ***********************#\n\n    def init_zo2_upload(self):\n        \"\"\"\n        Initializes the upload of essential model components to the GPU.\n        This method specifically handles the uploading of model embeddings and head components,\n        and prepares the offloading blocks based on configuration. 
This setup is crucial for\n        managing the active memory footprint during training by selectively uploading and\n        offloading transformer blocks as needed.\n        \"\"\"\n        print(\"Upload head and tail to cuda.\")\n        self.model.transformer.wte = self.model.transformer.wte.to(self.device)\n        self.model.transformer.wpe = self.model.transformer.wpe.to(self.device)\n        self.model.transformer.ln_f = self.model.transformer.ln_f.to(self.device)\n        self.model.lm_head = self.model.lm_head.to(self.device)\n\n        self.num_blocks = len(self.model.transformer.h)\n        if self.offloading_blocks is not None:\n            self.offloading_blocks = self.offloading_blocks\n        else:\n            self.offloading_blocks = list(range(self.num_blocks))\n        print(f\"Transformer blocks {self.offloading_blocks} will be offloaded to {self.offloading_device}\")\n        for i in range(self.num_blocks):\n            if i in self.offloading_blocks:\n                continue\n            else:\n                self.model.transformer.h[i] = self.model.transformer.h[i].to(self.device)\n                print(f\"Upload block {i} to cuda.\")\n    \n    @torch.inference_mode()   \n    def inner_zo_forward(self, idx, pos, targets):\n        \"\"\"\n        Defines the inner forward logic for zeroth-order optimization, applying perturbations\n        and calculating the loss for gradient estimation. 
This method, using nanogpt as an example, orchestrates the forward\n        computation across potentially offloaded transformer blocks, ensuring they are uploaded\n        for computation and offloaded post-computation as configured.\n\n        Args:\n            idx (Tensor): Input indices for token embeddings.\n            pos (Tensor): Position indices for positional embeddings.\n            targets (Tensor): Target outputs for loss calculation.\n\n        Returns:\n            Tuple[Tensor, Tensor]: The losses computed from two perturbed forward passes, used for gradient estimation.\n        \"\"\"\n        we1, we2 = self.task_compute_module(self.model.transformer.wte,\n                                inputs1={\"input\": idx},\n                                inputs2={\"input\": idx},\n                                grad=self.projected_grad)\n        pe1, pe2 = self.task_compute_module(self.model.transformer.wpe, \n                                 {\"input\": pos}, \n                                 {\"input\": pos}, \n                                 self.projected_grad)\n        hidden_states1, hidden_states2 = self.task_compute_function(torch.add,\n                                                                    {\"input\": we1, \"other\": pe1},\n                                                                    {\"input\": we2, \"other\": pe2})\n        if 0 in self.offloading_blocks:\n            self.model.transformer.h[0] = self.task_upload(\n                module=self.model.transformer.h[0], \n                device=self.device)\n        N = len(self.model.transformer.h)\n        for i in range(1, N):\n            if i != 1:\n                if i-2 in self.offloading_blocks:\n                    self.model.transformer.h[i-2] = self.task_offload(\n                        module=self.model.transformer.h[i-2], \n                        device=self.offloading_device)\n            hidden_states1, hidden_states2 = self.task_compute_module(\n           
     self.model.transformer.h[i-1], \n                inputs1={\"x\": hidden_states1}, \n                inputs2={\"x\": hidden_states2}, \n                grad=self.projected_grad)\n            if i in self.offloading_blocks:\n                self.model.transformer.h[i] = self.task_upload(\n                    module=self.model.transformer.h[i], \n                    device=self.device)\n        if N-2 in self.offloading_blocks:\n            self.model.transformer.h[N-2] = self.task_offload(\n                self.model.transformer.h[N-2], device=self.offloading_device)\n        hidden_states1, hidden_states2 = self.task_compute_module(\n                    self.model.transformer.h[N-1], \n                    inputs1={\"x\": hidden_states1}, \n                    inputs2={\"x\": hidden_states2}, \n                    grad=self.projected_grad\n                )\n        if N-1 in self.offloading_blocks:\n            self.model.transformer.h[N-1] = self.task_offload(\n                self.model.transformer.h[N-1], device=self.offloading_device)\n        logits1, logits2 = self.task_compute_module(self.model.transformer.ln_f,\n                                             inputs1={\"input\": hidden_states1}, \n                                             inputs2={\"input\": hidden_states2}, \n                                             grad=self.projected_grad,\n                                             weight_decay=0.)\n        logits1, logits2 = self.task_compute_module(self.model.lm_head,\n                                             inputs1={\"input\": logits1}, \n                                             inputs2={\"input\": logits2}, \n                                             grad=self.projected_grad)\n        loss1, loss2 = self.task_compute_function(F.cross_entropy,\n                                                  {\"input\": logits1[:, :-1, :].reshape(-1, logits1.size(-1)), \n                                                   \"target\": targets[:, 
1:].reshape(-1)},\n                                                  {\"input\": logits2[:, :-1, :].reshape(-1, logits2.size(-1)), \n                                                   \"target\": targets[:, 1:].reshape(-1)})\n        return loss1, loss2\n    \n    @torch.inference_mode()   \n    def inner_zo_eval_forward(self, eval_fn, idx, pos, targets):\n        \"\"\"\n        Conducts an evaluation forward pass of the model using the zeroth-order optimization setup,\n        but without applying any perturbations to ensure accurate performance assessment.\n        This function manages the dynamic uploading and offloading of transformer blocks as needed,\n        utilizing pre- and post-hooks to optimize memory usage during evaluation.\n\n        Args:\n            eval_fn (callable): The evaluation function to be applied, typically involves a forward pass\n                                that computes the loss or other metrics without updating model parameters.\n            idx (Tensor): Input indices for token embeddings.\n            pos (Tensor): Position indices for positional embeddings.\n            targets (Tensor): Target outputs for computing the evaluation metric (e.g., loss).\n\n        Returns:\n            Tensor: The output from the evaluation function, typically loss or accuracy metrics.\n        \"\"\"\n        handles = self.add_zo2_eval_comm_hooks(self.model.transformer.h)\n        output = eval_fn(idx, pos, targets)\n        self.clear_zo2_eval_comm_hooks(handles)\n        return output\n    "
  },
  {
    "path": "zo2/trainer/__init__.py",
    "content": ""
  },
  {
    "path": "zo2/trainer/hf_transformers/__init__.py",
    "content": "from .trainer import ZOTrainer"
  },
  {
    "path": "zo2/trainer/hf_transformers/trainer.py",
    "content": "# Copyright 2020-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.\n\"\"\"\n\nimport contextlib\nimport copy\nimport functools\nimport glob\nimport importlib.metadata\nimport inspect\nimport json\nimport math\nimport os\nimport random\nimport re\nimport shutil\nimport sys\nimport tempfile\nimport time\nimport warnings\nfrom collections.abc import Mapping\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Callable, Optional, Union\n\n\n# Integrations must be imported before ML frameworks:\n# isort: off\nfrom transformers.integrations import (\n    get_reporting_integration_callbacks,\n)\n\n# isort: on\n\nimport huggingface_hub.utils as hf_hub_utils\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom huggingface_hub import ModelCard, create_repo, upload_folder\nfrom packaging import version\nfrom torch import nn\nfrom torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler\n\nfrom transformers import Trainer\nfrom transformers import __version__\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator\nfrom transformers.debug_utils import DebugOption, DebugUnderflowOverflow\nfrom transformers.feature_extraction_sequence_utils 
import SequenceFeatureExtractor\nfrom transformers.feature_extraction_utils import FeatureExtractionMixin\nfrom transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend\nfrom transformers.image_processing_utils import BaseImageProcessor\nfrom transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available\nfrom transformers.integrations.tpu import tpu_spmd_dataloader\nfrom transformers.modelcard import TrainingSummary\nfrom transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model\nfrom transformers.models.auto.modeling_auto import (\n    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,\n    MODEL_MAPPING_NAMES,\n)\nfrom transformers.optimization import Adafactor, get_scheduler\nfrom transformers.processing_utils import ProcessorMixin\nfrom transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_2_3,\n)\nfrom transformers.tokenization_utils_base import PreTrainedTokenizerBase\nfrom transformers.trainer_callback import (\n    CallbackHandler,\n    DefaultFlowCallback,\n    ExportableState,\n    PrinterCallback,\n    ProgressCallback,\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n)\nfrom transformers.trainer_pt_utils import (\n    DistributedTensorGatherer,\n    EvalLoopContainer,\n    IterableDatasetShard,\n    LabelSmoother,\n    LayerWiseDummyOptimizer,\n    LengthGroupedSampler,\n    SequentialDistributedSampler,\n    distributed_broadcast_scalars,\n    distributed_concat,\n    find_batch_size,\n    get_model_param_count,\n    get_module_class_from_name,\n    get_parameter_names,\n    nested_concat,\n    nested_detach,\n    nested_numpify,\n    nested_xla_mesh_reduce,\n    reissue_pt_warnings,\n    remove_dummy_checkpoint,\n    set_rng_state_for_device,\n)\nfrom transformers.trainer_utils import (\n    PREFIX_CHECKPOINT_DIR,\n    BestRun,\n    EvalLoopOutput,\n    EvalPrediction,\n    
HPSearchBackend,\n    HubStrategy,\n    PredictionOutput,\n    RemoveColumnsCollator,\n    SaveStrategy,\n    TrainerMemoryTracker,\n    TrainOutput,\n    check_target_module_exists,\n    default_compute_objective,\n    denumpify_detensorize,\n    enable_full_determinism,\n    find_executable_batch_size,\n    get_last_checkpoint,\n    has_length,\n    neftune_post_forward_hook,\n    number_of_arguments,\n    seed_worker,\n    set_seed,\n    speed_metrics,\n)\nfrom transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments\nfrom transformers.utils import (\n    ADAPTER_CONFIG_NAME,\n    ADAPTER_SAFE_WEIGHTS_NAME,\n    ADAPTER_WEIGHTS_NAME,\n    CONFIG_NAME,\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n    XLA_FSDPV2_MIN_VERSION,\n    PushInProgress,\n    PushToHubMixin,\n    can_return_loss,\n    find_labels,\n    is_accelerate_available,\n    is_apex_available,\n    is_apollo_torch_available,\n    is_bitsandbytes_available,\n    is_datasets_available,\n    is_galore_torch_available,\n    is_grokadamw_available,\n    is_in_notebook,\n    is_ipex_available,\n    is_liger_kernel_available,\n    is_lomo_available,\n    is_peft_available,\n    is_safetensors_available,\n    is_sagemaker_dp_enabled,\n    is_sagemaker_mp_enabled,\n    is_schedulefree_available,\n    is_torch_compile_available,\n    is_torch_hpu_available,\n    is_torch_mlu_available,\n    is_torch_mps_available,\n    is_torch_musa_available,\n    is_torch_neuroncore_available,\n    is_torch_npu_available,\n    is_torch_xla_available,\n    is_torch_xpu_available,\n    is_torchao_available,\n    logging,\n    strtobool,\n)\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom transformers.utils.quantization_config import QuantizationMethod\n\n\nDEFAULT_CALLBACKS = [DefaultFlowCallback]\nDEFAULT_PROGRESS_CALLBACK = ProgressCallback\n\nif is_in_notebook():\n    from transformers.utils.notebook import 
NotebookProgressCallback\n\n    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback\n\nif is_apex_available():\n    from apex import amp\n\nif is_datasets_available():\n    import datasets\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n    import torch_xla.debug.metrics as met\n    from torch_xla import __version__ as XLA_VERSION\n\n    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)\n    if IS_XLA_FSDPV2_POST_2_2:\n        import torch_xla.distributed.spmd as xs\n        import torch_xla.runtime as xr\nelse:\n    IS_XLA_FSDPV2_POST_2_2 = False\n\n\nif is_sagemaker_mp_enabled():\n    import smdistributed.modelparallel.torch as smp\n    from smdistributed.modelparallel import __version__ as SMP_VERSION\n\n    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse(\"1.10\")\n\n    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat\nelse:\n    IS_SAGEMAKER_MP_POST_1_10 = False\n\n\nif is_safetensors_available():\n    import safetensors.torch\n\nif is_peft_available():\n    from peft import PeftModel\n\n\nif is_accelerate_available():\n    from accelerate import Accelerator, skip_first_batches\n    from accelerate import __version__ as accelerate_version\n    from accelerate.state import AcceleratorState\n    from accelerate.utils import (\n        AutocastKwargs,\n        DistributedDataParallelKwargs,\n        DistributedType,\n        load_fsdp_model,\n        load_fsdp_optimizer,\n        save_fsdp_model,\n        save_fsdp_optimizer,\n    )\n\n    DATA_SAMPLERS = [RandomSampler]\n    if version.parse(accelerate_version) > version.parse(\"1.3.0\"):\n        from accelerate.utils import TorchTensorParallelPlugin\n    if version.parse(accelerate_version) > version.parse(\"0.23.0\"):\n        from accelerate.data_loader import SeedableRandomSampler\n\n        DATA_SAMPLERS += [SeedableRandomSampler]\n\n    if 
is_deepspeed_available():\n        from accelerate.utils import DeepSpeedSchedulerWrapper\n\nif is_accelerate_available(\"0.28.0\"):\n    from accelerate.utils import DataLoaderConfiguration\n\n\ndef _is_peft_model(model):\n    if is_peft_available():\n        classes_to_check = (PeftModel,) if is_peft_available() else ()\n        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321\n        if version.parse(importlib.metadata.version(\"peft\")) >= version.parse(\"0.7.0\"):\n            from peft import PeftMixedModel\n\n            classes_to_check = (*classes_to_check, PeftMixedModel)\n        return isinstance(model, classes_to_check)\n    return False\n\n\ndef _get_fsdp_ckpt_kwargs():\n    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release\n    if is_accelerate_available() and \"adapter_only\" in list(inspect.signature(save_fsdp_model).parameters):\n        return {\"adapter_only\": True}\n    else:\n        return {}\n\n\ndef safe_globals():\n    # Starting from version 2.4 PyTorch introduces a check for the objects loaded\n    # with torch.load(weights_only=True). 
Starting from 2.6 weights_only=True becomes\n    # a default and requires allowlisting of objects being loaded.\n    # See: https://github.com/pytorch/pytorch/pull/137602\n    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals\n    # See: https://github.com/huggingface/accelerate/pull/3036\n    if version.parse(torch.__version__).release < version.parse(\"2.6\").release:\n        return contextlib.nullcontext()\n\n    np_core = np._core if version.parse(np.__version__) >= version.parse(\"2.0.0\") else np.core\n    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]\n    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for\n    # all versions of numpy\n    allowlist += [type(np.dtype(np.uint32))]\n\n    return torch.serialization.safe_globals(allowlist)\n\n\nif TYPE_CHECKING:\n    import optuna\n\n    if is_datasets_available():\n        import datasets\n\nlogger = logging.get_logger(__name__)\n\n\n# Name of the files used for checkpointing\nTRAINING_ARGS_NAME = \"training_args.bin\"\nTRAINER_STATE_NAME = \"trainer_state.json\"\nOPTIMIZER_NAME = \"optimizer.pt\"\nSCALER_NAME = \"scaler.pt\"\nOPTIMIZER_NAME_BIN = \"optimizer.bin\"\nSCHEDULER_NAME = \"scheduler.pt\"\nFSDP_MODEL_NAME = \"pytorch_model_fsdp\"\n\n\nclass ZOTrainer(Trainer):\n\n    # Those are used as methods of the Trainer in examples.\n    from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state\n\n    def __init__(\n        self,\n        model: Union[PreTrainedModel, nn.Module, None] = None,\n        args: TrainingArguments = None,\n        data_collator: Optional[DataCollator] = None,\n        train_dataset: Optional[Union[Dataset, IterableDataset, \"datasets.Dataset\"]] = None,\n        eval_dataset: Optional[Union[Dataset, dict[str, Dataset], \"datasets.Dataset\"]] = None,\n        processing_class: Optional[\n            Union[PreTrainedTokenizerBase, 
BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]\n        ] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_loss_func: Optional[Callable] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,\n        callbacks: Optional[list[TrainerCallback]] = None,\n        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),\n        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        super().__init__(model, args, data_collator, train_dataset, eval_dataset, processing_class, model_init, compute_loss_func, compute_metrics, callbacks, optimizers, optimizer_cls_and_kwargs, preprocess_logits_for_metrics)\n        \n        # ZO2 added: if using ZO2:\n        if hasattr(model, \"zo_training\"):\n            print(\"ZO training mode is enabled.\")\n            self.zo = True\n        else:\n            self.zo = False\n\n        # ZO2 added: currently unsupported conditions\n        if self.zo:\n            self._zo2_unsupported_conditions(args)\n        \n        # ZO2 added: init hooks buffer\n        if self.zo:\n            self.zo2_training_step_pre_hooks = []\n            self.zo2_training_step_post_hooks = []\n\n\n    def _inner_training_loop(\n        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None\n    ):\n        self.accelerator.free_memory()\n        self._train_batch_size = batch_size\n        if self.args.auto_find_batch_size:\n            if self.state.train_batch_size != self._train_batch_size:\n                from accelerate.utils import release_memory\n\n                (self.model_wrapped,) = release_memory(self.model_wrapped)\n                self.model_wrapped = self.model\n\n                # 
Check for DeepSpeed *after* the initial pass and modify the config\n                if self.is_deepspeed_enabled:\n                    # Temporarily unset `self.args.train_batch_size`\n                    original_bs = self.args.per_device_train_batch_size\n                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)\n                    self.propagate_args_to_deepspeed(True)\n                    self.args.per_device_train_batch_size = original_bs\n            self.state.train_batch_size = self._train_batch_size\n        logger.debug(f\"Currently training with a batch size of: {self._train_batch_size}\")\n        # Data loader and number of training steps\n        train_dataloader = self.get_train_dataloader()\n        if self.is_fsdp_xla_v2_enabled:\n            train_dataloader = tpu_spmd_dataloader(train_dataloader)\n\n        # Setting up training control variables:\n        # number of training epochs: num_train_epochs\n        # number of training steps per epoch: num_update_steps_per_epoch\n        # total number of training steps to execute: max_steps\n        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size\n        (\n            num_train_epochs,\n            num_update_steps_per_epoch,\n            num_examples,\n            num_train_samples,\n            epoch_based,\n            len_dataloader,\n            max_steps,\n        ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)\n\n        num_train_tokens = None\n        if self.args.include_tokens_per_second:\n            num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)\n            # If going by epochs, multiply tokens linearly\n            if len_dataloader is not None and epoch_based:\n                num_train_tokens *= args.num_train_epochs\n            # Otherwise since its steps, we just multiply by grad accum\n           
 else:\n                num_train_tokens *= args.gradient_accumulation_steps\n\n        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:\n            if self.args.n_gpu > 1:\n                # nn.DataParallel(model) replicates the model, creating new variables and module\n                # references registered here no longer work on other gpus, breaking the module\n                raise ValueError(\n                    \"Currently --debug underflow_overflow is not supported under DP. Please use DDP\"\n                    \" (torchrun or torch.distributed.launch (deprecated)).\"\n                )\n            else:\n                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa\n\n        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled\n\n        # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404\n        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, \"fsdp_version\", 1) == 2)\n        if is_fsdp2:\n            delay_optimizer_creation = False\n\n        # We need to reset the scheduler, as its parameters may be different on subsequent calls\n        if self._created_lr_scheduler:\n            self.lr_scheduler = None\n            self._created_lr_scheduler = False\n\n        if self.is_deepspeed_enabled:\n            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)\n\n        if not delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        self.state = TrainerState(\n            stateful_callbacks=[\n                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)\n            ]\n        )\n        self.state.is_hyper_param_search = trial is not None\n        self.state.train_batch_size = 
self._train_batch_size\n\n        # Compute absolute values for logging, eval, and save if given as ratio\n        self.state.compute_steps(args, max_steps)\n\n        # Activate gradient checkpointing if needed\n        if args.gradient_checkpointing:\n            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)\n\n        # ZO2 added ->\n        # model = self._wrap_model(self.model_wrapped)\n        model = self.model\n\n        # as the model is wrapped, don't use `accelerator.prepare`\n        # this is for unhandled cases such as\n        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX\n        use_accelerator_prepare = True if model is self.model else False\n\n        # ZO2 added ->\n        use_accelerator_prepare = False\n\n        if use_accelerator_prepare and self.is_fsdp_enabled:\n            # In case of auto_find_batch_size=True\n            # Remove FSDP wrapping from sub-models.\n            self.model = unwrap_model(self.model, recursive=True)\n\n        if delay_optimizer_creation:\n            if use_accelerator_prepare:\n                # configure fsdp plugin for qlora if any\n                self._fsdp_qlora_plugin_updates()\n                if self.accelerator.mixed_precision != \"fp8\":\n                    self.model = self.accelerator.prepare(self.model)\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        # prepare using `accelerator` prepare\n        if use_accelerator_prepare:\n            self.model.train()\n            if hasattr(self.lr_scheduler, \"step\"):\n                if self.use_apex:\n                    model = self.accelerator.prepare(self.model)\n                else:\n                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)\n            else:\n                # to handle cases wherein we pass \"DummyScheduler\" such as when it is specified in DeepSpeed config.\n                model, 
self.optimizer, self.lr_scheduler = self.accelerator.prepare(\n                    self.model, self.optimizer, self.lr_scheduler\n                )\n        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:\n            # In this case we are in DDP + LOMO, which should be supported\n            self.optimizer = self.accelerator.prepare(self.optimizer)\n\n        if self.is_fsdp_enabled:\n            self.model = self.model_wrapped = model\n\n        # for the rest of this function `model` is the outside model, whether it was wrapped or not\n        if model is not self.model:\n            self.model_wrapped = model\n\n        # backward compatibility\n        if self.is_deepspeed_enabled:\n            self.deepspeed = self.model_wrapped\n\n        # ZO2 added ->\n        if delay_optimizer_creation:\n            if self.zo:\n                self.create_optimizer_and_scheduler(num_training_steps=max_steps, model=model)\n            else:\n                self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        # ZO2 added ->\n        # Check if saved optimizer or scheduler states exist\n        if self.zo:\n            _, model = self._load_optimizer_and_scheduler(resume_from_checkpoint, model)\n        else:\n            self._load_optimizer_and_scheduler(resume_from_checkpoint)\n\n        # # ckpt loading\n        # if resume_from_checkpoint is not None:\n        #     if self.is_deepspeed_enabled:\n        #         deepspeed_load_checkpoint(\n        #             self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)\n        #         )\n        #     elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:\n        #         self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)\n\n        # # Check if saved optimizer or scheduler states exist\n        # self._load_optimizer_and_scheduler(resume_from_checkpoint)\n        # 
self._load_scaler(resume_from_checkpoint)\n\n        # important: at this point:\n        # self.model         is the Transformers Model\n        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),\n        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.\n\n        # Train!\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {num_examples:,}\")\n        logger.info(f\"  Num Epochs = {num_train_epochs:,}\")\n        logger.info(f\"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}\")\n        if self.args.per_device_train_batch_size != self._train_batch_size:\n            logger.info(f\"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}\")\n        logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}\")\n        logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n        logger.info(f\"  Total optimization steps = {max_steps:,}\")\n        logger.info(f\"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}\")\n\n        self.state.epoch = 0\n        start_time = time.time()\n        epochs_trained = 0\n        steps_trained_in_current_epoch = 0\n        steps_trained_progress_bar = None\n\n        # Check if continuing training from a checkpoint\n        if resume_from_checkpoint is not None and os.path.isfile(\n            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)\n        ):\n            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))\n            self.compare_trainer_and_checkpoint_args(self.args, self.state)\n            self._load_callback_state()\n            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)\n            if not args.ignore_data_skip:\n                
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)\n                steps_trained_in_current_epoch *= args.gradient_accumulation_steps\n            else:\n                steps_trained_in_current_epoch = 0\n\n            logger.info(\"  Continuing training from checkpoint, will skip to saved global_step\")\n            logger.info(f\"  Continuing training from epoch {epochs_trained}\")\n            logger.info(f\"  Continuing training from global step {self.state.global_step}\")\n            if not args.ignore_data_skip:\n                logger.info(\n                    f\"  Will skip the first {epochs_trained} epochs then the first\"\n                    f\" {steps_trained_in_current_epoch} batches in the first epoch.\"\n                )\n\n        # Update the references\n        for attr in (\"model\", \"optimizer\", \"lr_scheduler\"):\n            setattr(self.callback_handler, attr, getattr(self, attr))\n        self.callback_handler.train_dataloader = train_dataloader\n\n        self.state.init_training_references(self, max_steps, num_train_epochs, trial)\n\n        # tr_loss is a tensor to avoid synchronization of TPUs through .item()\n        tr_loss = torch.tensor(0.0, device=args.device)\n        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses\n        self._total_loss_scalar = 0.0\n        self._globalstep_last_logged = self.state.global_step\n        model.zero_grad()\n        grad_norm: Optional[float] = None\n        learning_rate = None\n        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)\n\n        if args.eval_on_start:\n            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)\n\n        for epoch in range(epochs_trained, num_train_epochs):\n            epoch_dataloader = train_dataloader\n            if hasattr(epoch_dataloader, \"set_epoch\"):\n                
epoch_dataloader.set_epoch(epoch)\n\n            # Reset the past mems state at the beginning of each epoch if necessary.\n            if args.past_index >= 0:\n                self._past = None\n\n            steps_in_epoch = (\n                len(epoch_dataloader)\n                if len_dataloader is not None\n                else args.max_steps * args.gradient_accumulation_steps\n            )\n            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)\n\n            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:\n                self._load_rng_state(resume_from_checkpoint)\n\n            rng_to_sync = False\n            steps_skipped = 0\n            if steps_trained_in_current_epoch > 0:\n                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)\n                steps_skipped = steps_trained_in_current_epoch\n                steps_trained_in_current_epoch = 0\n                rng_to_sync = True\n\n            step = -1\n            epoch_iterator = iter(epoch_dataloader)\n            # We chunkify the epoch iterator into gradient accumulation steps `n` batches\n            remainder = num_examples % args.gradient_accumulation_steps\n            if remainder == 0:\n                remainder = args.gradient_accumulation_steps\n            update_step = -1\n            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1\n            if args.gradient_accumulation_steps == 1:\n                total_updates -= 1\n            for _ in range(total_updates):\n                update_step += 1\n                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder\n                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)\n                for i, inputs in enumerate(batch_samples):\n                    step 
+= 1\n                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch\n                    # Since we perform prefetching, we need to manually set sync_gradients\n                    self.accelerator.gradient_state._set_sync_gradients(do_sync_step)\n\n                    if self.args.include_num_input_tokens_seen:\n                        main_input_name = getattr(self.model, \"main_input_name\", \"input_ids\")\n                        if main_input_name not in inputs:\n                            logger.warning(\n                                \"Tried to track the number of tokens seen, however the current model is \"\n                                \"not configured properly to know what item is the input. To fix this, add \"\n                                \"a `main_input_name` attribute to the model class you are using.\"\n                            )\n                        else:\n                            input_tokens = inputs[main_input_name].numel()\n                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)\n                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()\n                    if rng_to_sync:\n                        self._load_rng_state(resume_from_checkpoint)\n                        rng_to_sync = False\n\n                    # Skip past any already trained steps if resuming training\n                    if steps_trained_in_current_epoch > 0:\n                        steps_trained_in_current_epoch -= 1\n                        if steps_trained_progress_bar is not None:\n                            steps_trained_progress_bar.update(1)\n                        if steps_trained_in_current_epoch == 0:\n                            self._load_rng_state(resume_from_checkpoint)\n                        continue\n                    elif steps_trained_progress_bar is not None:\n        
                steps_trained_progress_bar.close()\n                        steps_trained_progress_bar = None\n\n                    if step % args.gradient_accumulation_steps == 0:\n                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)\n\n\n                    # ZO2 added -> estimate gradient and updates\n                    if self.zo:\n                        tr_loss_step = self.zo2_training_step(model, inputs)\n                    else:\n                        # We explicitly want to avoid relying on `accelerator.accumulate` for generation training\n                        context = (\n                            functools.partial(self.accelerator.no_sync, model=model)\n                            if i != len(batch_samples) - 1\n                            and self.accelerator.distributed_type != DistributedType.DEEPSPEED\n                            else contextlib.nullcontext\n                        )\n                        with context():\n                            tr_loss_step = self.training_step(model, inputs, num_items_in_batch)\n\n                    if (\n                        args.logging_nan_inf_filter\n                        and not is_torch_xla_available()\n                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))\n                    ):\n                        # if loss is nan or inf simply add the average of previous logged losses\n                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                    else:\n                        if tr_loss.device != tr_loss_step.device:\n                            raise ValueError(\n                                f\"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}\"\n                            )\n                        tr_loss = tr_loss + tr_loss_step\n\n                    self.current_flos += 
float(self.floating_point_ops(inputs))\n\n                    if do_sync_step:\n                        # ZO2 added -> ignore parameter update since it is fused with model forward\n                        if self.zo:\n                            pass\n                        else:\n                            # Since we perform prefetching, we need to manually set sync_gradients to True\n                            self.accelerator.gradient_state._set_sync_gradients(True)\n\n                            # Gradient clipping\n                            if args.max_grad_norm is not None and args.max_grad_norm > 0:\n                                if is_sagemaker_mp_enabled() and args.fp16:\n                                    _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)\n                                elif self.use_apex:\n                                    # Revert to normal clipping otherwise, handling Apex or full precision\n                                    _grad_norm = nn.utils.clip_grad_norm_(\n                                        amp.master_params(self.optimizer),\n                                        args.max_grad_norm,\n                                    )\n                                else:\n                                    _grad_norm = self.accelerator.clip_grad_norm_(\n                                        model.parameters(),\n                                        args.max_grad_norm,\n                                    )\n\n                                if (\n                                    is_accelerate_available()\n                                    and self.accelerator.distributed_type == DistributedType.DEEPSPEED\n                                ):\n                                    grad_norm = model.get_global_grad_norm()\n                                    # In some cases the grad norm may not return a float\n                                    if hasattr(grad_norm, \"item\"):\n                      
                  grad_norm = grad_norm.item()\n                                else:\n                                    grad_norm = _grad_norm\n\n                            self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)\n\n                            self.optimizer.step()\n\n                            self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)\n\n                            # get learning rate before update\n                            learning_rate = self._get_learning_rate()\n\n                            if not self.accelerator.optimizer_step_was_skipped:\n                                # Delay optimizer scheduling until metrics are generated\n                                if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):\n                                    self.lr_scheduler.step()\n\n                            model.zero_grad()\n\n                        self.state.global_step += 1\n                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch\n                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)\n                        self._maybe_log_save_evaluate(\n                            tr_loss,\n                            grad_norm,\n                            model,\n                            trial,\n                            epoch,\n                            ignore_keys_for_eval,\n                            start_time,\n                            learning_rate=learning_rate,\n                        )\n                    else:\n                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)\n\n                    # PyTorch/XLA relies on the data loader to insert the mark_step for\n                    # each step. 
Since we are breaking the loop early, we need to manually\n                    # insert the mark_step here.\n                    if self.control.should_epoch_stop or self.control.should_training_stop:\n                        if is_torch_xla_available():\n                            xm.mark_step()\n                        break\n                # We also need to break out of the nested loop\n                if self.control.should_epoch_stop or self.control.should_training_stop:\n                    if is_torch_xla_available():\n                        xm.mark_step()\n                    break\n            \n            if step < 0:\n                logger.warning(\n                    \"There seems not to be a single sample in your epoch_iterator, stopping training at step\"\n                    f\" {self.state.global_step}! This is expected if you're using an IterableDataset and set\"\n                    f\" num_steps ({max_steps}) higher than the number of available samples.\"\n                )\n                self.control.should_training_stop = True\n\n            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)\n            self._maybe_log_save_evaluate(\n                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate\n            )\n\n            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n                if is_torch_xla_available():\n                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n                    xm.master_print(met.metrics_report())\n                else:\n                    logger.warning(\n                        \"You enabled PyTorch/XLA debug metrics but you don't have a TPU \"\n                        \"configured. 
Check your training configuration if this is unexpected.\"\n                    )\n            if self.control.should_training_stop:\n                break\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of training\n            delattr(self, \"_past\")\n\n        logger.info(\"\\n\\nTraining completed. Do not forget to share your model on huggingface.co/models =)\\n\\n\")\n        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:\n            # Wait for everyone to get here so we are sure the model has been saved by process 0.\n            if is_torch_xla_available():\n                xm.rendezvous(\"load_best_model_at_end\")\n            elif args.parallel_mode == ParallelMode.DISTRIBUTED:\n                dist.barrier()\n            elif is_sagemaker_mp_enabled():\n                smp.barrier()\n\n            self._load_best_model()\n\n        # add remaining tr_loss\n        self._total_loss_scalar += tr_loss.item()\n        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError\n        train_loss = self._total_loss_scalar / effective_global_step\n\n        metrics = speed_metrics(\n            \"train\",\n            start_time,\n            num_samples=num_train_samples,\n            num_steps=self.state.max_steps,\n            num_tokens=num_train_tokens,\n        )\n        self.store_flos()\n        metrics[\"total_flos\"] = self.state.total_flos\n        metrics[\"train_loss\"] = train_loss\n\n        self.is_in_train = False\n\n        self._memory_tracker.stop_and_update_metrics(metrics)\n\n        self.log(metrics)\n\n        run_dir = self._get_output_dir(trial)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)\n\n        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.\n        if self.args.should_save and 
self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:\n            for checkpoint in checkpoints_sorted:\n                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):\n                    logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n                    shutil.rmtree(checkpoint, ignore_errors=True)\n\n        self.control = self.callback_handler.on_train_end(args, self.state, self.control)\n\n        # Wait for the checkpoint to be uploaded.\n        self._finish_current_push()\n\n        # After training we make sure to retrieve back the original forward pass method\n        # for the embedding layer by removing the forward post hook.\n        if self.neftune_noise_alpha is not None:\n            self._deactivate_neftune(self.model)\n\n        return TrainOutput(self.state.global_step, train_loss, metrics)\n\n\n    def _load_optimizer_and_scheduler(self, checkpoint, model=None):\n        \"\"\"\n        disable the optimizer resume.\n        \"\"\"\n        output = super()._load_optimizer_and_scheduler(checkpoint)\n        if self.zo and model is not None:\n            model.opt = self.optimizer\n            return output, model\n        return output\n\n    def create_optimizer_and_scheduler(self, num_training_steps: int, model: nn.Module=None):\n        \"\"\"\n        disable the optimizer but leave the learning rate scheduler.\n        \"\"\"\n        if not self.zo:\n            self.create_optimizer()\n            if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:\n                # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer\n                optimizer = self.optimizer.optimizer\n            else:\n                optimizer = self.optimizer\n        else:\n            if model is None:\n                optimizer = self.optimizer = self.model.opt\n            else:\n                optimizer = self.optimizer = model.opt\n        
self.create_scheduler(num_training_steps, optimizer)\n\n    def _move_model_to_device(self, model, device):\n        pass\n\n    #*********************** zo2 functions ***********************#\n\n    def _zo2_unsupported_conditions(self, args):\n        if args.gradient_accumulation_steps > 1:\n            raise NotImplementedError\n        if args.n_gpu > 1:\n            raise NotImplementedError(\"Currently ZO2 only support one working device\")\n        if args.deepspeed:\n            raise NotImplementedError\n        if is_sagemaker_mp_enabled():\n            raise NotImplementedError\n        if args.torch_compile:\n            raise NotImplementedError\n\n    def register_zo2_training_step_pre_hook(self, hook_fn):\n        \"\"\"\n        example:\n            def print_zo_info(model, inputs):\n                tqdm.write(\"projected grad: {}\".format(model.opt.projected_grad))\n                return model, inputs\n            trainer = ZOTrainer(...)\n            trainer.register_zo2_training_step_pre_hook(print_zo_info)\n        \"\"\"\n        self.zo2_training_step_pre_hooks.append(hook_fn)\n\n    def register_zo2_training_step_post_hook(self, hook_fn):\n        \"\"\"\n        example:\n            def drop_invalid_data(model, inputs, loss):\n                # Extract projected_grad, handle both tensor and scalar cases\n                projected_grad = model.opt.projected_grad\n                if isinstance(projected_grad, torch.Tensor):\n                    projected_grad_is_nan = torch.isnan(projected_grad).any()\n                else:\n                    projected_grad_is_nan = projected_grad != projected_grad  # Check for NaN in scalars\n                if torch.isnan(loss) or projected_grad_is_nan:\n                    tqdm.write(\"'loss': {} or 'projected_grad': {} is nan. 
Drop this step.\".format(\n                        loss, model.opt.projected_grad\n                    ))\n                    model.opt.projected_grad = 0  # Reset projected_grad to prevent parameter updates\n                return model, inputs, loss\n            trainer = ZOTrainer(...)\n            trainer.register_zo2_training_step_post_hook(drop_invalid_data)\n        \"\"\"\n        self.zo2_training_step_post_hooks.append(hook_fn)\n\n    def zo2_training_step(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:\n        if self.zo2_training_step_pre_hooks != []:\n            for pre_hook_fn in self.zo2_training_step_pre_hooks:\n                model, inputs = pre_hook_fn(model, inputs)\n        model.zo_train()\n        inputs = self._prepare_inputs(inputs)\n        loss = model(**inputs)\n        model.zo_eval()\n        if self.zo2_training_step_post_hooks != []:\n            for post_hook_fn in self.zo2_training_step_post_hooks:\n                model, inputs, loss = post_hook_fn(model, inputs, loss)\n        return loss\n    "
  },
  {
    "path": "zo2/trainer/hf_trl/__init__.py",
    "content": "from .sft_trainer import ZOSFTTrainer"
  },
  {
    "path": "zo2/trainer/hf_trl/sft_trainer.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom trl import SFTTrainer\n\nimport contextlib\nimport functools\nimport glob\nimport inspect\nimport math\nimport os\nimport random\nimport re\nimport shutil\nimport sys\nimport time\nimport warnings\nfrom collections.abc import Mapping\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union\nimport copy\nimport numpy as np\nfrom functools import wraps\n\nfrom tqdm.auto import tqdm\nfrom transformers import Trainer, DataCollator\nfrom sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV\n\n# Integrations must be imported before ML frameworks:\nfrom transformers.integrations import (  # isort: split\n    default_hp_search_backend,\n    get_reporting_integration_callbacks,\n    hp_params,\n    is_fairscale_available,\n    is_optuna_available,\n    is_ray_tune_available,\n    is_sigopt_available,\n    is_wandb_available,\n    run_hp_search_optuna,\n    run_hp_search_ray,\n    run_hp_search_sigopt,\n    run_hp_search_wandb,\n)\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom packaging import version\nfrom torch import nn\nfrom torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\nfrom torch.utils.data.distributed import DistributedSampler\n\nfrom huggingface_hub import Repository\n\nfrom transformers 
import __version__\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator\nfrom transformers.debug_utils import DebugOption, DebugUnderflowOverflow\nfrom transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled\nfrom transformers.dependency_versions_check import dep_version_check\nfrom transformers.modelcard import TrainingSummary\nfrom transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model\nfrom transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES\nfrom transformers.optimization import Adafactor, get_scheduler\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11\nfrom transformers.tokenization_utils_base import PreTrainedTokenizerBase\nfrom transformers.trainer_callback import (\n    CallbackHandler,\n    DefaultFlowCallback,\n    PrinterCallback,\n    ProgressCallback,\n    TrainerCallback,\n    TrainerControl,\n    TrainerState,\n)\nfrom transformers.trainer_pt_utils import (\n    DistributedLengthGroupedSampler,\n    DistributedSamplerWithLoop,\n    DistributedTensorGatherer,\n    IterableDatasetShard,\n    LabelSmoother,\n    LengthGroupedSampler,\n    SequentialDistributedSampler,\n    ShardSampler,\n    distributed_broadcast_scalars,\n    distributed_concat,\n    find_batch_size,\n    get_module_class_from_name,\n    get_parameter_names,\n    nested_concat,\n    nested_detach,\n    nested_numpify,\n    nested_truncate,\n    nested_xla_mesh_reduce,\n    reissue_pt_warnings,\n)\nfrom transformers.trainer_utils import (\n    PREFIX_CHECKPOINT_DIR,\n    BestRun,\n    EvalLoopOutput,\n    EvalPrediction,\n    FSDPOption,\n    HPSearchBackend,\n    HubStrategy,\n    IntervalStrategy,\n    PredictionOutput,\n    RemoveColumnsCollator,\n    ShardedDDPOption,\n    TrainerMemoryTracker,\n   
 TrainOutput,\n    default_compute_objective,\n    default_hp_space,\n    denumpify_detensorize,\n    enable_full_determinism,\n    find_executable_batch_size,\n    get_last_checkpoint,\n    has_length,\n    number_of_arguments,\n    seed_worker,\n    set_seed,\n    speed_metrics,\n)\nfrom transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments\nfrom transformers.utils import (\n    CONFIG_NAME,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n    find_labels,\n    get_full_repo_name,\n    is_apex_available,\n    is_datasets_available,\n    is_in_notebook,\n    is_ipex_available,\n    is_sagemaker_dp_enabled,\n    is_sagemaker_mp_enabled,\n    is_torch_tensorrt_fx_available,\n    is_torch_tpu_available,\n    is_accelerate_available,\n    is_torchdynamo_available,\n    logging,\n)\nfrom transformers.utils.generic import ContextManagers\nfrom transformers.trainer_pt_utils import (\n    _get_learning_rate, \n    log_metrics, \n    metrics_format, \n    save_metrics, \n    save_state,\n    get_model_param_count,\n)\n\n_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10\n\nDEFAULT_CALLBACKS = [DefaultFlowCallback]\nDEFAULT_PROGRESS_CALLBACK = ProgressCallback\n\nif is_in_notebook():\n    from transformers.utils.notebook import NotebookProgressCallback\n\n    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback\n\nif is_apex_available():\n    from apex import amp\n\nif is_datasets_available():\n    import datasets\n\nif is_torch_tpu_available(check_device=False):\n    import torch_xla.core.xla_model as xm\n    import torch_xla.debug.metrics as met\n    import torch_xla.distributed.parallel_loader as pl\n\nif is_fairscale_available():\n    dep_version_check(\"fairscale\")\n    import fairscale\n    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP\n    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP\n    from fairscale.nn.wrap import auto_wrap\n    from fairscale.optim import 
OSS\n    from fairscale.optim.grad_scaler import ShardedGradScaler\n\n\nif is_sagemaker_mp_enabled():\n    import smdistributed.modelparallel.torch as smp\n    from smdistributed.modelparallel import __version__ as SMP_VERSION\n\n    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse(\"1.10\")\n\n    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat\nelse:\n    IS_SAGEMAKER_MP_POST_1_10 = False\n\n\nskip_first_batches = None\nif is_accelerate_available():\n    from accelerate import __version__ as accelerate_version\n\n    if version.parse(accelerate_version) >= version.parse(\"0.16\"):\n        from accelerate import skip_first_batches\n\n\nif TYPE_CHECKING:\n    import optuna\n\nlogger = logging.get_logger(__name__)\n\n\n# Name of the files used for checkpointing\nTRAINING_ARGS_NAME = \"training_args.bin\"\nTRAINER_STATE_NAME = \"trainer_state.json\"\nOPTIMIZER_NAME = \"optimizer.pt\"\nSCHEDULER_NAME = \"scheduler.pt\"\nSCALER_NAME = \"scaler.pt\"\n\n\nclass ZOSFTTrainer(SFTTrainer):\n    \n    def __init__(\n        self,\n        model: Union[PreTrainedModel, nn.Module, str],\n        args: TrainingArguments = None,\n        data_collator: Optional[DataCollator] = None,\n        train_dataset: Optional[Dataset] = None,\n        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = None,\n        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        peft_config: Optional[\"PeftConfig\"] = None,\n        dataset_text_field: Optional[str] = None,\n        packing: 
Optional[bool] = False,\n        formatting_func: Optional[Callable] = None,\n        max_seq_length: Optional[int] = None,\n        infinite: Optional[bool] = False,\n        num_of_sequences: Optional[int] = 1024,\n        chars_per_token: Optional[float] = 3.6,\n        dataset_num_proc: Optional[int] = None,\n        dataset_batch_size: int = 1000,\n        neftune_noise_alpha: Optional[float] = None,\n        model_init_kwargs: Optional[Dict] = None,\n    ):\n        # ZO2 added: if using ZO2:\n        if hasattr(model, \"zo_training\"):\n            print(\"ZO training mode is enabled.\")\n            self.zo = True\n        else:\n            self.zo = False\n\n        # ZO2 added: currently unsupported conditions\n        if self.zo:\n            self._zo2_unsupported_conditions(args)\n        \n        # ZO2 added: init hooks buffer\n        if self.zo:\n            self.zo2_training_step_pre_hooks = []\n            self.zo2_training_step_post_hooks = []\n\n        super().__init__(model, args, data_collator, train_dataset, eval_dataset,\n                         tokenizer, model_init, compute_metrics, callbacks,\n                         optimizers, preprocess_logits_for_metrics, peft_config,\n                         dataset_text_field, packing, formatting_func,\n                         max_seq_length, infinite, num_of_sequences, \n                         chars_per_token, dataset_num_proc, dataset_batch_size,\n                         neftune_noise_alpha, model_init_kwargs)\n\n    def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):\n        \"\"\"\n        We overload the original training loop to add ZO2. 
Search key word \"ZO2 added\"\n        for those updates.\n        \"\"\"\n\n        self._train_batch_size = batch_size\n        # Data loader and number of training steps\n        train_dataloader = self.get_train_dataloader()\n\n        # Setting up training control variables:\n        # number of training epochs: num_train_epochs\n        # number of training steps per epoch: num_update_steps_per_epoch\n        # total number of training steps to execute: max_steps\n        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size\n\n        len_dataloader = None\n        if has_length(train_dataloader):\n            len_dataloader = len(train_dataloader)\n            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps\n            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n            num_examples = self.num_examples(train_dataloader)\n            if args.max_steps > 0:\n                max_steps = args.max_steps\n                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(\n                    args.max_steps % num_update_steps_per_epoch > 0\n                )\n                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's\n                # the best we can do.\n                num_train_samples = args.max_steps * total_train_batch_size\n            else:\n                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)\n                num_train_epochs = math.ceil(args.num_train_epochs)\n                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs\n        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size\n            max_steps = args.max_steps\n            # Setting a very large number of epochs so we go as many times as necessary over the iterator.\n            num_train_epochs = 
sys.maxsize\n            num_update_steps_per_epoch = max_steps\n            num_examples = total_train_batch_size * args.max_steps\n            num_train_samples = args.max_steps * total_train_batch_size\n        else:\n            raise ValueError(\n                \"args.max_steps must be set to a positive value if dataloader does not have a length, was\"\n                f\" {args.max_steps}\"\n            )\n\n        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:\n            if self.args.n_gpu > 1:\n                # nn.DataParallel(model) replicates the model, creating new variables and module\n                # references registered here no longer work on other gpus, breaking the module\n                raise ValueError(\n                    \"Currently --debug underflow_overflow is not supported under DP. Please use DDP\"\n                    \" (torch.distributed.launch).\"\n                )\n            else:\n                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa\n\n        delay_optimizer_creation = (\n            self.sharded_ddp is not None\n            and self.sharded_ddp != ShardedDDPOption.SIMPLE\n            or is_sagemaker_mp_enabled()\n            or self.fsdp is not None\n        )\n        if args.deepspeed:\n            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(\n                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint\n            )\n            self.model = deepspeed_engine.module\n            self.model_wrapped = deepspeed_engine\n            self.deepspeed = deepspeed_engine\n            self.optimizer = optimizer\n            self.lr_scheduler = lr_scheduler\n        elif not delay_optimizer_creation:\n            self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        self.state = TrainerState()\n        self.state.is_hyper_param_search = trial is not None\n\n        # Activate gradient checkpointing if needed\n        if 
args.gradient_checkpointing:\n            self.model.gradient_checkpointing_enable()\n\n        # ZO2 added ->\n        # model = self._wrap_model(self.model_wrapped)\n        model = self.model\n\n        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:\n            self._load_from_checkpoint(resume_from_checkpoint, model)\n\n        # for the rest of this function `model` is the outside model, whether it was wrapped or not\n        if model is not self.model:\n            self.model_wrapped = model\n\n        # ZO2 added ->\n        if delay_optimizer_creation:\n            if self.zo:\n                self.create_optimizer_and_scheduler(num_training_steps=max_steps, model=model)\n            else:\n                self.create_optimizer_and_scheduler(num_training_steps=max_steps)\n\n        # ZO2 added ->\n        # Check if saved optimizer or scheduler states exist\n        if self.zo:\n            _, model = self._load_optimizer_and_scheduler(resume_from_checkpoint, model)\n        else:\n            self._load_optimizer_and_scheduler(resume_from_checkpoint)\n\n        # important: at this point:\n        # self.model         is the Transformers Model\n        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.\n\n        # Train!\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {num_examples:,}\")\n        logger.info(f\"  Num Epochs = {num_train_epochs:,}\")\n        logger.info(f\"  Instantaneous batch size per device = {args.per_device_train_batch_size:,}\")\n        logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}\")\n        logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n        logger.info(f\"  Total optimization steps = {max_steps:,}\")\n        logger.info(f\"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}\")\n\n        self.state.epoch = 0\n        start_time = time.time()\n        epochs_trained = 0\n        steps_trained_in_current_epoch = 0\n        steps_trained_progress_bar = None\n\n        # Check if continuing training from a checkpoint\n        if resume_from_checkpoint is not None and os.path.isfile(\n            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)\n        ):\n            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))\n            epochs_trained = self.state.global_step // num_update_steps_per_epoch\n            if not args.ignore_data_skip:\n                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)\n                steps_trained_in_current_epoch *= args.gradient_accumulation_steps\n            else:\n                steps_trained_in_current_epoch = 0\n\n            logger.info(\"  Continuing training from checkpoint, will skip to saved global_step\")\n            logger.info(f\"  Continuing training from epoch {epochs_trained}\")\n            logger.info(f\"  Continuing training from global step {self.state.global_step}\")\n            if not args.ignore_data_skip:\n                if skip_first_batches is None:\n                    logger.info(\n                        f\"  Will skip the first {epochs_trained} epochs then the first\"\n                        f\" {steps_trained_in_current_epoch} batches in the first epoch. 
If this takes a lot of time,\"\n                        \" you can install the latest version of Accelerate with `pip install -U accelerate`.You can\"\n                        \" also add the `--ignore_data_skip` flag to your launch command, but you will resume the\"\n                        \" training on data already seen by your model.\"\n                    )\n                else:\n                    logger.info(\n                        f\"  Will skip the first {epochs_trained} epochs then the first\"\n                        f\" {steps_trained_in_current_epoch} batches in the first epoch.\"\n                    )\n                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:\n                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)\n                    steps_trained_progress_bar.set_description(\"Skipping the first batches\")\n\n        # Update the references\n        self.callback_handler.model = self.model\n        self.callback_handler.optimizer = self.optimizer\n        self.callback_handler.lr_scheduler = self.lr_scheduler\n        self.callback_handler.train_dataloader = train_dataloader\n        if self.hp_name is not None and self._trial is not None:\n            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial\n            # parameter to Train when using DDP.\n            self.state.trial_name = self.hp_name(self._trial)\n        if trial is not None:\n            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial\n            self.state.trial_params = hp_params(assignments)\n        else:\n            self.state.trial_params = None\n        # This should be the same if the state has been saved but in case the training arguments changed, it's safer\n        # to set this after the load.\n        self.state.max_steps = max_steps\n        self.state.num_train_epochs = 
num_train_epochs\n        self.state.is_local_process_zero = self.is_local_process_zero()\n        self.state.is_world_process_zero = self.is_world_process_zero()\n\n        # tr_loss is a tensor to avoid synchronization of TPUs through .item()\n        tr_loss = torch.tensor(0.0).to(args.device)\n        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses\n        self._total_loss_scalar = 0.0\n        self._globalstep_last_logged = self.state.global_step\n        model.zero_grad()\n\n        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)\n\n        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.\n        if not args.ignore_data_skip:\n            for epoch in range(epochs_trained):\n                is_random_sampler = hasattr(train_dataloader, \"sampler\") and isinstance(\n                    train_dataloader.sampler, RandomSampler\n                )\n                if is_torch_less_than_1_11 or not is_random_sampler:\n                    # We just need to begin an iteration to create the randomization of the sampler.\n                    # That was before PyTorch 1.11 however...\n                    for _ in train_dataloader:\n                        break\n                else:\n                    # Otherwise we need to call the whooooole sampler cause there is some random operation added\n                    # AT THE VERY END!\n                    _ = list(train_dataloader.sampler)\n\n        total_batched_samples = 0\n        for epoch in range(epochs_trained, num_train_epochs):\n            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):\n                train_dataloader.sampler.set_epoch(epoch)\n            elif hasattr(train_dataloader, \"dataset\") and isinstance(train_dataloader.dataset, IterableDatasetShard):\n                
train_dataloader.dataset.set_epoch(epoch)\n\n            if is_torch_tpu_available():\n                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)\n                epoch_iterator = parallel_loader\n            else:\n                epoch_iterator = train_dataloader\n\n            # Reset the past mems state at the beginning of each epoch if necessary.\n            if args.past_index >= 0:\n                self._past = None\n\n            steps_in_epoch = (\n                len(epoch_iterator)\n                if len_dataloader is not None\n                else args.max_steps * args.gradient_accumulation_steps\n            )\n            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)\n\n            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:\n                self._load_rng_state(resume_from_checkpoint)\n\n            rng_to_sync = False\n            steps_skipped = 0\n            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:\n                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)\n                steps_skipped = steps_trained_in_current_epoch\n                steps_trained_in_current_epoch = 0\n                rng_to_sync = True\n\n            step = -1\n            for step, inputs in enumerate(epoch_iterator):\n                total_batched_samples += 1\n                if rng_to_sync:\n                    self._load_rng_state(resume_from_checkpoint)\n                    rng_to_sync = False\n\n                # Skip past any already trained steps if resuming training\n                if steps_trained_in_current_epoch > 0:\n                    steps_trained_in_current_epoch -= 1\n                    if steps_trained_progress_bar is not None:\n                        steps_trained_progress_bar.update(1)\n                    if 
steps_trained_in_current_epoch == 0:\n                        self._load_rng_state(resume_from_checkpoint)\n                    continue\n                elif steps_trained_progress_bar is not None:\n                    steps_trained_progress_bar.close()\n                    steps_trained_progress_bar = None\n\n                if step % args.gradient_accumulation_steps == 0:\n                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)\n\n                # ZO2 added -> estimate gradient and updates\n                if self.zo:\n                    tr_loss_step = self.zo2_training_step(model, inputs)\n                else:\n                    if (\n                        (total_batched_samples % args.gradient_accumulation_steps != 0)\n                        and args.local_rank != -1\n                        and args._no_sync_in_gradient_accumulation\n                    ):\n                        # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.\n                        with model.no_sync():\n                            tr_loss_step = self.training_step(model, inputs)\n                    else:\n                        tr_loss_step = self.training_step(model, inputs)\n\n                if (\n                    args.logging_nan_inf_filter\n                    and not is_torch_tpu_available()\n                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))\n                ):\n                    # if loss is nan or inf simply add the average of previous logged losses\n                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)\n                else:\n                    tr_loss += tr_loss_step\n\n                self.current_flos += float(self.floating_point_ops(inputs))\n\n                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps\n                
if self.deepspeed:\n                    self.deepspeed.step()\n\n                if total_batched_samples % args.gradient_accumulation_steps == 0 or (\n                    # last step in epoch but step is always smaller than gradient_accumulation_steps\n                    steps_in_epoch <= args.gradient_accumulation_steps\n                    and (step + 1) == steps_in_epoch\n                ):\n                    # ZO2 added -> ignore parameter update since it is fuesd with model forward\n                    if self.zo:\n                        pass\n                    else:\n                        # Gradient clipping\n                        if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:\n                            # deepspeed does its own clipping\n\n                            if self.do_grad_scaling:\n                                # Reduce gradients first for XLA\n                                if is_torch_tpu_available():\n                                    gradients = xm._fetch_gradients(self.optimizer)\n                                    xm.all_reduce(\"sum\", gradients, scale=1.0 / xm.xrt_world_size())\n                                # AMP: gradients need unscaling\n                                self.scaler.unscale_(self.optimizer)\n\n                            if is_sagemaker_mp_enabled() and args.fp16:\n                                self.optimizer.clip_master_grads(args.max_grad_norm)\n                            elif hasattr(self.optimizer, \"clip_grad_norm\"):\n                                # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping\n                                self.optimizer.clip_grad_norm(args.max_grad_norm)\n                            elif hasattr(model, \"clip_grad_norm_\"):\n                                # Some models (like FullyShardedDDP) have a specific way to do gradient clipping\n                                
model.clip_grad_norm_(args.max_grad_norm)\n                            else:\n                                # Revert to normal clipping otherwise, handling Apex or full precision\n                                nn.utils.clip_grad_norm_(\n                                    amp.master_params(self.optimizer) if self.use_apex else model.parameters(),\n                                    args.max_grad_norm,\n                                )\n\n                        # Optimizer step\n                        optimizer_was_run = True\n                        if self.deepspeed:\n                            pass  # called outside the loop\n                        elif is_torch_tpu_available():\n                            if self.do_grad_scaling:\n                                self.scaler.step(self.optimizer)\n                                self.scaler.update()\n                            else:\n                                xm.optimizer_step(self.optimizer)\n                        elif self.do_grad_scaling:\n                            scale_before = self.scaler.get_scale()\n                            self.scaler.step(self.optimizer)\n                            self.scaler.update()\n                            scale_after = self.scaler.get_scale()\n                            optimizer_was_run = scale_before <= scale_after\n                        else:\n                            self.optimizer.step()\n\n                        if optimizer_was_run and not self.deepspeed:\n                            self.lr_scheduler.step()\n\n                        model.zero_grad()\n\n                    self.state.global_step += 1\n                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch\n                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)\n\n                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n                else:\n                    
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)\n\n                if self.control.should_epoch_stop or self.control.should_training_stop:\n                    break\n            if step < 0:\n                logger.warning(\n                    \"There seems to be not a single sample in your epoch_iterator, stopping training at step\"\n                    f\" {self.state.global_step}! This is expected if you're using an IterableDataset and set\"\n                    f\" num_steps ({max_steps}) higher than the number of available samples.\"\n                )\n                self.control.should_training_stop = True\n\n            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)\n            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)\n\n            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:\n                if is_torch_tpu_available():\n                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)\n                    xm.master_print(met.metrics_report())\n                else:\n                    logger.warning(\n                        \"You enabled PyTorch/XLA debug metrics but you don't have a TPU \"\n                        \"configured. Check your training configuration if this is unexpected.\"\n                    )\n            if self.control.should_training_stop:\n                break\n\n        if args.past_index and hasattr(self, \"_past\"):\n            # Clean the state at the end of training\n            delattr(self, \"_past\")\n\n        logger.info(\"\\n\\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\\n\\n\")\n        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:\n            # Wait for everyone to get here so we are sur the model has been saved by process 0.\n            if is_torch_tpu_available():\n                xm.rendezvous(\"load_best_model_at_end\")\n            elif args.local_rank != -1:\n                dist.barrier()\n            elif is_sagemaker_mp_enabled():\n                smp.barrier()\n\n            self._load_best_model()\n\n        # add remaining tr_loss\n        self._total_loss_scalar += tr_loss.item()\n        train_loss = self._total_loss_scalar / self.state.global_step\n\n        metrics = speed_metrics(\"train\", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)\n        self.store_flos()\n        metrics[\"total_flos\"] = self.state.total_flos\n        metrics[\"train_loss\"] = train_loss\n\n        self.is_in_train = False\n\n        self._memory_tracker.stop_and_update_metrics(metrics)\n\n        self.log(metrics)\n\n        run_dir = self._get_output_dir(trial)\n        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)\n\n        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.\n        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:\n            for checkpoint in checkpoints_sorted:\n                if checkpoint != self.state.best_model_checkpoint:\n                    logger.info(f\"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit\")\n                    shutil.rmtree(checkpoint)\n\n        self.control = self.callback_handler.on_train_end(args, self.state, self.control)\n\n        return TrainOutput(self.state.global_step, train_loss, metrics)\n    \n\n    @wraps(Trainer.train)\n    def train(self, *args, 
**kwargs):\n        \"\"\"\n            ZO2 does not support neftune.\n        \"\"\"\n        # # Activate neftune right before training.\n        # if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:\n        #     self.model = self._trl_activate_neftune(self.model)\n\n        output = Trainer.train(self, *args, **kwargs)\n\n        # # After training we make sure to retrieve back the original forward pass method\n        # # for the embedding layer by removing the forward post hook.\n        # if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:\n        #     unwrapped_model = unwrap_model(self.model)\n        #     if is_peft_available() and isinstance(unwrapped_model, PeftModel):\n        #         embeddings = unwrapped_model.base_model.model.get_input_embeddings()\n        #     else:\n        #         embeddings = unwrapped_model.get_input_embeddings()\n\n        #     self.neftune_hook_handle.remove()\n        #     del embeddings.neftune_noise_alpha\n\n        return output\n\n\n    def _load_optimizer_and_scheduler(self, checkpoint, model=None):\n        \"\"\"\n        disable the optimizer resume.\n        \"\"\"\n        output = super()._load_optimizer_and_scheduler(checkpoint)\n        if self.zo and model is not None:\n            model.opt = self.optimizer\n            return output, model\n        return output\n\n    def create_optimizer_and_scheduler(self, num_training_steps: int, model: nn.Module=None):\n        \"\"\"\n        disable the optimizer but leave the learning rate scheduler.\n        \"\"\"\n        if not self.zo:\n            self.create_optimizer()\n            if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:\n                # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer\n                optimizer = self.optimizer.optimizer\n            else:\n                optimizer = self.optimizer\n        else:\n            if model is None:\n                
optimizer = self.optimizer = self.model.opt\n            else:\n                optimizer = self.optimizer = model.opt\n        self.create_scheduler(num_training_steps, optimizer)\n\n    def _move_model_to_device(self, model, device):\n        pass\n\n    #*********************** zo2 functions ***********************#\n\n    def _zo2_unsupported_conditions(self, args):\n        if args.gradient_accumulation_steps > 1:\n            raise NotImplementedError\n        if args.n_gpu > 1:\n            raise NotImplementedError(\"Currently ZO2 only support one working device\")\n        if args.deepspeed:\n            raise NotImplementedError\n        if is_torch_tpu_available(check_device=False):\n            raise NotImplementedError\n        if is_fairscale_available():\n            raise NotImplementedError\n        if is_sagemaker_mp_enabled():\n            raise NotImplementedError\n        if args.torch_compile:\n            raise NotImplementedError\n\n    def register_zo2_training_step_pre_hook(self, hook_fn):\n        self.zo2_training_step_pre_hooks.append(hook_fn)\n\n    def register_zo2_training_step_post_hook(self, hook_fn):\n        self.zo2_training_step_post_hooks.append(hook_fn)\n\n    def zo2_training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:\n        if self.zo2_training_step_pre_hooks != []:\n            for pre_hook_fn in self.zo2_training_step_pre_hooks:\n                model, inputs = pre_hook_fn(model, inputs)\n        model.zo_train()\n        inputs = self._prepare_inputs(inputs)\n        loss = model(**inputs)\n        model.zo_eval()\n        if self.zo2_training_step_post_hooks != []:\n            for post_hook_fn in self.zo2_training_step_post_hooks:\n                model, inputs, loss = post_hook_fn(model, inputs, loss)\n        return loss\n    "
  },
  {
    "path": "zo2/utils/__init__.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom .utils import seed_everything"
  },
  {
    "path": "zo2/utils/utils.py",
    "content": "# Copyright (c) 2025 liangyuwang\n# Licensed under the Apache License, Version 2.0\n\nfrom torch import nn\nimport torch\nimport os\nimport random\nimport numpy as np\n\ndef print_all(module: nn.Module, inputs, outputs):\n    \"\"\"Print min/max/mean statistics of a module's parameters, inputs and outputs.\n\n    The (module, inputs, outputs) signature matches a PyTorch forward hook.\n\n    Args:\n        module: module whose parameters are printed as (min, max, mean).\n        inputs: a tensor, a mapping-like object exposing .items(), or any other\n            iterable such as the positional-argument tuple a forward hook\n            receives. Only tensor entries are printed, as (min, max).\n        outputs: a tensor, printed as (min, max, mean); anything else is\n            reported as unrecognized.\n    \"\"\"\n    print(\"Param: \")\n    for p in module.parameters():\n        print(p.min().item(), p.max().item(), p.mean().item())\n    print(\"Inputs: \")\n    if isinstance(inputs, torch.Tensor):\n        print(inputs.min().item(), inputs.max().item())\n    else:\n        # Forward hooks pass positional inputs as a tuple, which has no .items();\n        # iterate such containers directly instead of crashing.\n        values = [v for _, v in inputs.items()] if hasattr(inputs, \"items\") else inputs\n        for value in values:\n            if isinstance(value, torch.Tensor):\n                print(value.min().item(), value.max().item())\n    print(\"Output: \")\n    if isinstance(outputs, torch.Tensor):\n        print(outputs.min().item(), outputs.max().item(), outputs.mean().item())\n    else:\n        print(\"Unrecongized outputs.\")\n    print(\"*\" * 20)\n\n\ndef print_hook(module, input, output):\n    \"\"\"Forward-hook style debug print of a module's weight range and output statistics.\n\n    Expects module to expose a tensor weight attribute (e.g. nn.Linear) and\n    output to be a single tensor; other modules or outputs will raise.\n    \"\"\"\n    print(module, f\"{module.weight.min().item():.4f}, {module.weight.max().item():.4f}\")\n    print(f\"{output.min().item():.8f} {output.max().item():.8f} {output.mean().item():.8f}\")\n\ndef print_para_and_device(model):\n    \"\"\"Print the device of every named parameter of model.\"\"\"\n    for p, v in model.named_parameters():\n        print(f\"{p}: {v.device}\")\n\ndef cal_self_reg_loss(logits, labels):\n    \"\"\"Next-token (shifted) cross-entropy loss.\n\n    The prediction at position t (logits[:, t, :]) is scored against the label\n    at position t + 1, so the last logit step and the first label are dropped.\n    logits is indexed as (batch, seq, num_classes) and labels as (batch, seq).\n    Uses the nn.CrossEntropyLoss defaults (mean reduction, ignore_index=-100).\n    \"\"\"\n    loss = nn.CrossEntropyLoss()(\n        logits[:, :-1, :].reshape(-1, logits.size(-1)),\n        labels[:, 1:].reshape(-1)\n    )\n    return loss\n\ndef seed_everything(seed):\n    \"\"\"Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.\"\"\"\n    random.seed(seed)\n    # Only inherited by child processes: hash randomization of the running\n    # interpreter is fixed at startup.\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n"
  }
]